1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project
19 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
21 #include "mlir/IR/Attributes.h" // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
24 #include "mlir/IR/Operation.h" // from @llvm-project
25 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
26 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
27 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
28 #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/utils.h"
29
30 namespace mlir {
31 namespace kernel_gen {
32 namespace tf_framework {
33 namespace {
34
35 using transforms::CreateOrFindGlobalStringConstant;
36 using transforms::GetGlobalName;
37 using transforms::GetOrInsertLLVMFunction;
38
39 static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
40 static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
41 static constexpr StringRef kCInterfaceReportError =
42 "_mlir_ciface_tf_report_error";
43 static constexpr StringRef kCInterfaceJITCompile =
44 "_mlir_ciface_tf_jit_compile";
45 static constexpr StringRef kCInterfaceJITExecute =
46 "_mlir_ciface_tf_jit_execute";
47 static constexpr StringRef kJITCodeGlobalBaseName = "jit_module_code";
48 static constexpr StringRef kErrorMessageGlobalBaseName = "error_message";
49
50 /// Base class for patterns converting TF Framework ops to function calls.
51 template <typename OpTy>
52 class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
53 public:
54 using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
55
56 protected:
57 virtual StringRef GetFuncName() const = 0;
58 virtual Type GetFuncType() const = 0;
59
ConvertArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter,std::function<Value (Attribute)> create_element) const60 std::pair<Value, Value> ConvertArrayAttrToStackAllocatedArray(
61 Location loc, Type size_ty, Type element_ty,
62 llvm::Optional<ArrayAttr> attr, ConversionPatternRewriter *rewriter,
63 std::function<Value(Attribute)> create_element) const {
64 Type element_ptr_ty = LLVM::LLVMPointerType::get(element_ty);
65
66 // If the attribute is missing or empty, set the element count to 0 and
67 // return NULL.
68 if (!attr.has_value() || attr.getValue().empty()) {
69 Value zero = rewriter->create<LLVM::ConstantOp>(
70 loc, size_ty, rewriter->getIntegerAttr(size_ty, 0));
71 Value null_ptr = rewriter->create<LLVM::NullOp>(loc, element_ptr_ty);
72 return std::make_pair(zero, null_ptr);
73 }
74
75 // Allocate array to store the elements.
76 auto &array_attr = attr.getValue();
77 Value array_size = rewriter->create<LLVM::ConstantOp>(
78 loc, size_ty, rewriter->getIntegerAttr(size_ty, array_attr.size()));
79 Value array_ptr = rewriter->create<LLVM::AllocaOp>(
80 loc, element_ptr_ty, array_size, /*alignment=*/0);
81 for (auto &e : llvm::enumerate(array_attr)) {
82 Value index = rewriter->create<LLVM::ConstantOp>(
83 loc, size_ty, rewriter->getIntegerAttr(size_ty, e.index()));
84 Value element_ptr =
85 rewriter->create<LLVM::GEPOp>(loc, element_ptr_ty, array_ptr, index);
86 Value element = create_element(e.value());
87 rewriter->create<LLVM::StoreOp>(loc, element, element_ptr);
88 }
89 return std::make_pair(array_size, array_ptr);
90 }
91
ConvertIntegerArrayAttrToStackAllocatedArray(Location loc,Type size_ty,Type element_ty,llvm::Optional<ArrayAttr> attr,ConversionPatternRewriter * rewriter) const92 std::pair<Value, Value> ConvertIntegerArrayAttrToStackAllocatedArray(
93 Location loc, Type size_ty, Type element_ty,
94 llvm::Optional<ArrayAttr> attr,
95 ConversionPatternRewriter *rewriter) const {
96 assert(size_ty.isa<IntegerType>() && "expect integer size type");
97 assert(element_ty.isa<IntegerType>() && "expect integer element type");
98 return ConvertArrayAttrToStackAllocatedArray(
99 loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) {
100 return rewriter->create<LLVM::ConstantOp>(
101 loc, element_ty,
102 rewriter->getIntegerAttr(element_ty,
103 attr.cast<IntegerAttr>().getInt()));
104 });
105 }
106 };
107
108 class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
109 public:
110 using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
111
matchAndRewrite(TFAllocOp tf_alloc_op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const112 LogicalResult matchAndRewrite(
113 TFAllocOp tf_alloc_op, OpAdaptor adaptor,
114 ConversionPatternRewriter &rewriter) const override {
115 mlir::Operation *op = tf_alloc_op.getOperation();
116 Location loc = op->getLoc();
117 MemRefType memref_type = tf_alloc_op.getType();
118
119 // Get memref descriptor sizes.
120 SmallVector<Value, 4> sizes;
121 SmallVector<Value, 4> strides;
122 Value sizeBytes;
123 getMemRefDescriptorSizes(loc, memref_type,
124 llvm::to_vector<4>(adaptor.dyn_sizes()), rewriter,
125 sizes, strides, sizeBytes);
126 // Get number of elements.
127 Value num_elements = getNumElements(loc, sizes, rewriter);
128 // Get element size.
129 Value element_size =
130 getSizeInBytes(loc, memref_type.getElementType(), rewriter);
131
132 // Convert `output_index` or set it to -1 if the attribute is missing.
133 Type llvmInt32Type = IntegerType::get(rewriter.getContext(), 32);
134 Value output_index = rewriter.create<LLVM::ConstantOp>(
135 loc, llvmInt32Type,
136 rewriter.getI32IntegerAttr(tf_alloc_op.output_index().has_value()
137 ? tf_alloc_op.output_index().getValue()
138 : -1));
139
140 // Convert `candidate_input_indices`.
141 auto candidates_count_and_ptr =
142 ConvertIntegerArrayAttrToStackAllocatedArray(
143 loc, rewriter.getI32Type(), rewriter.getI32Type(),
144 tf_alloc_op.input_indices(), &rewriter);
145
146 // Insert function call.
147 FlatSymbolRefAttr tf_func_ref =
148 GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
149 Value allocated_byte_ptr =
150 rewriter
151 .create<LLVM::CallOp>(
152 loc, getVoidPtrType(), tf_func_ref,
153 llvm::makeArrayRef({adaptor.ctx(), num_elements, element_size,
154 output_index,
155 candidates_count_and_ptr.first,
156 candidates_count_and_ptr.second}))
157 .getResult();
158
159 MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
160 loc, rewriter, memref_type, allocated_byte_ptr, sizes);
161
162 // Return the final value of the descriptor.
163 rewriter.replaceOp(op, {memRefDescriptor});
164 return success();
165 }
166
167 protected:
GetFuncName() const168 StringRef GetFuncName() const override { return kCInterfaceAlloc; }
169
GetFuncType() const170 Type GetFuncType() const override {
171 Type llvm_i32_type = IntegerType::get(getDialect().getContext(), 32);
172 Type llvm_i32_ptr_type = LLVM::LLVMPointerType::get(llvm_i32_type);
173 Type llvm_void_ptr_type = getVoidPtrType();
174 return LLVM::LLVMFunctionType::get(
175 llvm_void_ptr_type,
176 llvm::makeArrayRef(
177 {/*void* op_kernel_ctx*/ llvm_void_ptr_type,
178 /*size_t num_elements*/ getIndexType(),
179 /*size_t element_size*/ getIndexType(),
180 /*int32_t output_index*/ llvm_i32_type,
181 /*int32_t num_candidates*/ llvm_i32_type,
182 /*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}));
183 }
184
185 private:
186 // TODO(pifon): Remove strides computation.
CreateMemRefDescriptor(Location loc,ConversionPatternRewriter & rewriter,MemRefType memref_type,Value allocated_byte_ptr,ArrayRef<Value> sizes) const187 MemRefDescriptor CreateMemRefDescriptor(Location loc,
188 ConversionPatternRewriter &rewriter,
189 MemRefType memref_type,
190 Value allocated_byte_ptr,
191 ArrayRef<Value> sizes) const {
192 auto memref_desc = MemRefDescriptor::undef(
193 rewriter, loc, typeConverter->convertType(memref_type));
194
195 // TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
196 Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
197 loc, getElementPtrType(memref_type), allocated_byte_ptr);
198 memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
199 memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
200 memref_desc.setConstantOffset(rewriter, loc, 0);
201
202 if (memref_type.getRank() == 0) {
203 return memref_desc;
204 }
205
206 // Compute strides and populate descriptor `size` and `stride` fields.
207 Value stride_carried = createIndexConstant(rewriter, loc, 1);
208 for (int pos = sizes.size() - 1; pos >= 0; --pos) {
209 Value size = sizes[pos];
210 memref_desc.setSize(rewriter, loc, pos, size);
211 memref_desc.setStride(rewriter, loc, pos, stride_carried);
212 // Update stride
213 if (pos > 0) {
214 stride_carried =
215 rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
216 }
217 }
218 return memref_desc;
219 }
220 };
221
222 class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
223 public:
224 using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
225
matchAndRewrite(TFDeallocOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const226 LogicalResult matchAndRewrite(
227 TFDeallocOp op, OpAdaptor adaptor,
228 ConversionPatternRewriter &rewriter) const override {
229 // TODO(herhut) Support unranked memrefs.
230 if (!op.memref().getType().isa<MemRefType>()) return failure();
231 MemRefDescriptor memref(adaptor.memref());
232
233 Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
234 op.getLoc(), getVoidPtrType(),
235 memref.allocatedPtr(rewriter, op.getLoc()));
236
237 // Insert function call.
238 FlatSymbolRefAttr tf_func_ref =
239 GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
240 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
241 op, llvm::None, tf_func_ref,
242 llvm::makeArrayRef({adaptor.ctx(), allocated_bytes_ptr}));
243 return success();
244 }
245
246 protected:
GetFuncName() const247 StringRef GetFuncName() const override { return kCInterfaceDealloc; }
GetFuncType() const248 Type GetFuncType() const override {
249 return LLVM::LLVMFunctionType::get(getVoidType(),
250 {getVoidPtrType(), getVoidPtrType()});
251 }
252 };
253
254 class JITCompileFromStrOpConverter
255 : public ConvertToLLVMCallOpPattern<JITCompileFromStrOp> {
256 using ConvertToLLVMCallOpPattern<
257 JITCompileFromStrOp>::ConvertToLLVMCallOpPattern;
258
matchAndRewrite(JITCompileFromStrOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const259 LogicalResult matchAndRewrite(
260 JITCompileFromStrOp op, OpAdaptor adaptor,
261 ConversionPatternRewriter &rewriter) const override {
262 if (adaptor.ctx() == nullptr) return failure();
263 auto loc = op.getLoc();
264 std::string zero_terminated_code = op.code().str() + '\00';
265 Value jit_module_code = CreateOrFindGlobalStringConstant(
266 loc, GetGlobalName(kJITCodeGlobalBaseName, zero_terminated_code),
267 zero_terminated_code, &rewriter);
268 std::pair<Value, Value> tile_sizes =
269 ConvertIntegerArrayAttrToStackAllocatedArray(loc, rewriter.getI64Type(),
270 rewriter.getI64Type(),
271 op.tileSizes(), &rewriter);
272 std::pair<Value, Value> unroll_factors =
273 ConvertIntegerArrayAttrToStackAllocatedArray(
274 loc, rewriter.getI64Type(), rewriter.getI64Type(),
275 op.unrollFactors(), &rewriter);
276 Value max_supported_rank = rewriter.create<LLVM::ConstantOp>(
277 loc, rewriter.getI64Type(), op.maxSupportedRankAttr());
278 Value enable_ftz = rewriter.create<LLVM::ConstantOp>(
279 loc, rewriter.getI1Type(), op.enableFtzAttr());
280 Value index_64bit = rewriter.create<LLVM::ConstantOp>(
281 loc, rewriter.getI1Type(), op.index64BitAttr());
282 Value cpu_codegen = rewriter.create<LLVM::ConstantOp>(
283 loc, rewriter.getI1Type(), op.cpuCodegenAttr());
284 FlatSymbolRefAttr tf_func_ref =
285 GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
286 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
287 op, getVoidPtrType(), tf_func_ref,
288 llvm::makeArrayRef({adaptor.ctx(), jit_module_code, tile_sizes.first,
289 tile_sizes.second, unroll_factors.first,
290 unroll_factors.second, max_supported_rank,
291 enable_ftz, index_64bit, cpu_codegen}));
292 return success();
293 }
294
295 protected:
GetFuncName() const296 StringRef GetFuncName() const override { return kCInterfaceJITCompile; }
297
GetFuncType() const298 Type GetFuncType() const override {
299 auto i8_ptr_ty =
300 LLVM::LLVMPointerType::get(IntegerType::get(getContext(), 8));
301 auto i64_ty = IntegerType::get(getContext(), 64);
302 Type i64_ptr_ty = LLVM::LLVMPointerType::get(i64_ty);
303 auto i1_ty = IntegerType::get(getContext(), 1);
304 return LLVM::LLVMFunctionType::get(
305 getVoidPtrType(), {/*void* op_kernel_ctx*/ getVoidPtrType(),
306 /*char* code*/ i8_ptr_ty,
307 /*int64_t num_tile_sizes*/ i64_ty,
308 /*int64_t* tile_sizes_ptr*/ i64_ptr_ty,
309 /*int64_t num_unroll_factors*/ i64_ty,
310 /*int64_t* unroll_factors_ptr*/ i64_ptr_ty,
311 /*int64_t max_supported_rank*/ i64_ty,
312 /*bool enable_ftz*/ i1_ty,
313 /*bool index_64bit*/ i1_ty,
314 /*bool cpu_codegen*/ i1_ty});
315 }
316 };
317
318 class JITExecuteOpConverter : public ConvertToLLVMCallOpPattern<JITExecuteOp> {
319 public:
320 using ConvertToLLVMCallOpPattern<JITExecuteOp>::ConvertToLLVMCallOpPattern;
321
matchAndRewrite(JITExecuteOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const322 LogicalResult matchAndRewrite(
323 JITExecuteOp op, OpAdaptor adaptor,
324 ConversionPatternRewriter &rewriter) const override {
325 // The TF context must be known for a successful lowering.
326 if (adaptor.ctx() == nullptr || op.operands().empty()) {
327 return failure();
328 }
329
330 // Allocate result on stack.
331 auto loc = op.getLoc();
332 Type result_ty =
333 getTypeConverter()->convertType(op->getResultTypes().front());
334 Type result_ptr_ty = LLVM::LLVMPointerType::get(result_ty);
335 Type i64_ty = rewriter.getI64Type();
336 Value one = rewriter.create<LLVM::ConstantOp>(
337 loc, i64_ty, rewriter.getI64IntegerAttr(1));
338 auto result_ptr =
339 rewriter.create<LLVM::AllocaOp>(loc, result_ptr_ty, one, llvm::None);
340 Type void_ptr_ty = getVoidPtrType();
341 auto result_void_ptr =
342 rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, result_ptr);
343
344 // Pass the buffer arguments as a stack-allocated array.
345 Type arg_ptr_ty =
346 LLVM::LLVMPointerType::get(adaptor.operands().front().getType());
347 Value num_args = rewriter.create<LLVM::ConstantOp>(
348 loc, i64_ty,
349 rewriter.getI64IntegerAttr(
350 static_cast<int64_t>(adaptor.operands().size())));
351 Value args_ptr = rewriter.create<LLVM::AllocaOp>(loc, arg_ptr_ty, num_args,
352 /*alignment=*/0);
353 for (const auto &it : llvm::enumerate(adaptor.operands())) {
354 Value index = rewriter.create<LLVM::ConstantOp>(
355 loc, i64_ty, rewriter.getI64IntegerAttr(it.index()));
356 Value element_ptr =
357 rewriter.create<LLVM::GEPOp>(loc, arg_ptr_ty, args_ptr, index);
358 rewriter.create<LLVM::StoreOp>(loc, it.value(), element_ptr);
359 }
360 auto args_void_ptr =
361 rewriter.create<LLVM::BitcastOp>(loc, void_ptr_ty, args_ptr);
362
363 // Materialize runtime call.
364 FlatSymbolRefAttr tf_func_ref =
365 GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
366 rewriter.create<LLVM::CallOp>(
367 loc, llvm::None, tf_func_ref,
368 ValueRange{adaptor.ctx(), adaptor.callable(), result_void_ptr, num_args,
369 args_void_ptr});
370
371 // Copy result (including the descriptor) to a stack-allocated buffer and
372 // free the old descriptor.
373 llvm::SmallVector<Value, 1> final_result = {
374 rewriter.create<LLVM::LoadOp>(loc, result_ptr)};
375 if (failed(copyUnrankedDescriptors(rewriter, loc, op->getResultTypes(),
376 final_result,
377 /*toDynamic=*/false))) {
378 return failure();
379 }
380
381 rewriter.replaceOp(op, final_result.front());
382 return success();
383 }
384
385 protected:
GetFuncName() const386 StringRef GetFuncName() const override { return kCInterfaceJITExecute; }
387
GetFuncType() const388 Type GetFuncType() const override {
389 auto i64_ty = IntegerType::get(getContext(), 64);
390 auto void_ptr_ty = getVoidPtrType();
391 return LLVM::LLVMFunctionType::get(getVoidType(),
392 {/*void* op_kernel_ctx*/ void_ptr_ty,
393 /*void* callable*/ void_ptr_ty,
394 /*void* result*/ void_ptr_ty,
395 /*int64_t num_args*/ i64_ty,
396 /*void* args_ptr*/ void_ptr_ty});
397 }
398 };
399
400 class ReportErrorOpConverter
401 : public ConvertToLLVMCallOpPattern<ReportErrorOp> {
402 public:
403 using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
404
matchAndRewrite(ReportErrorOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const405 LogicalResult matchAndRewrite(
406 ReportErrorOp op, OpAdaptor adaptor,
407 ConversionPatternRewriter &rewriter) const override {
408 Location loc = op.getLoc();
409 auto module = op->getParentOfType<ModuleOp>();
410 Value message_constant =
411 GenerateErrorMessageConstant(loc, module, adaptor.msg(), rewriter);
412
413 // Insert function call.
414 FlatSymbolRefAttr tf_func_ref =
415 GetOrInsertLLVMFunction(GetFuncName(), GetFuncType(), op, &rewriter);
416 Value error_code = rewriter.create<LLVM::ConstantOp>(
417 loc, typeConverter->convertType(rewriter.getI32Type()),
418 adaptor.error_codeAttr());
419 rewriter.replaceOpWithNewOp<LLVM::CallOp>(
420 op, llvm::None, tf_func_ref,
421 llvm::makeArrayRef({adaptor.ctx(), error_code, message_constant}));
422 return success();
423 }
424
425 protected:
GetFuncName() const426 StringRef GetFuncName() const override { return kCInterfaceReportError; }
GetFuncType() const427 Type GetFuncType() const override {
428 MLIRContext *ctx = &getTypeConverter()->getContext();
429 auto i8_ptr_type = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
430 auto i32_type = IntegerType::get(ctx, 32);
431 return LLVM::LLVMFunctionType::get(
432 getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type});
433 }
434
435 private:
436 // Generates an LLVM IR dialect global that contains the name of the given
437 // kernel function as a C string, and returns a pointer to its beginning.
GenerateErrorMessageConstant(Location loc,Operation * module,StringRef message,OpBuilder & builder) const438 Value GenerateErrorMessageConstant(Location loc, Operation *module,
439 StringRef message,
440 OpBuilder &builder) const {
441 std::string err_str;
442 llvm::raw_string_ostream err_stream(err_str);
443 err_stream << message;
444 if (!loc.isa<UnknownLoc>()) {
445 err_stream << " at ";
446 loc.print(err_stream);
447 }
448 err_stream << '\00';
449 StringRef generated_error(err_stream.str());
450 return CreateOrFindGlobalStringConstant(
451 loc, GetGlobalName(kErrorMessageGlobalBaseName, generated_error),
452 generated_error, &builder);
453 }
454 };
455
456 class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
457 public:
458 using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
459
matchAndRewrite(NullContextOp op,OpAdaptor,ConversionPatternRewriter & rewriter) const460 LogicalResult matchAndRewrite(
461 NullContextOp op, OpAdaptor /*adaptor*/,
462 ConversionPatternRewriter &rewriter) const override {
463 rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
464 return success();
465 }
466 };
467
468 class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
469 public:
470 using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
471
matchAndRewrite(NullMemRefOp null_memref_op,OpAdaptor,ConversionPatternRewriter & rewriter) const472 LogicalResult matchAndRewrite(
473 NullMemRefOp null_memref_op, OpAdaptor /*adaptor*/,
474 ConversionPatternRewriter &rewriter) const override {
475 Location loc = null_memref_op->getLoc();
476 LLVMTypeConverter type_converter = *getTypeConverter();
477 mlir::Operation *op = null_memref_op.getOperation();
478
479 auto shaped_result_type = null_memref_op.getType().cast<BaseMemRefType>();
480 unsigned address_space = shaped_result_type.getMemorySpaceAsInt();
481
482 Type elem_type = shaped_result_type.getElementType();
483 Type llvm_elem_type = type_converter.convertType(elem_type);
484
485 Value zero = createIndexConstant(rewriter, loc, 0);
486 if (auto result_type = null_memref_op.getType().dyn_cast<MemRefType>()) {
487 // Set all dynamic sizes to 1 and compute fake strides.
488 SmallVector<Value, 4> dyn_sizes(result_type.getNumDynamicDims(),
489 createIndexConstant(rewriter, loc, 1));
490 SmallVector<Value, 4> sizes, strides;
491 Value sizeBytes;
492 getMemRefDescriptorSizes(loc, result_type, dyn_sizes, rewriter, sizes,
493 strides, sizeBytes);
494
495 // Prepare packed args [allocatedPtr, alignedPtr, offset, sizes, strides]
496 // to create a memref descriptor.
497 Value null = rewriter.create<LLVM::NullOp>(
498 loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
499 SmallVector<Value, 12> packed_values{null, null, zero};
500 packed_values.append(sizes);
501 packed_values.append(strides);
502
503 rewriter.replaceOp(
504 op, MemRefDescriptor::pack(rewriter, loc, type_converter, result_type,
505 packed_values));
506 return success();
507 }
508
509 auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
510 Type llvm_result_type = type_converter.convertType(result_type);
511
512 auto desc =
513 UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
514 desc.setRank(rewriter, loc, zero);
515
516 // Due to the current way of handling unranked memref results escaping, we
517 // have to actually construct a ranked underlying descriptor instead of just
518 // setting its pointer to NULL.
519 SmallVector<Value, 4> sizes;
520 UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
521 desc, sizes);
522 Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
523 loc, getVoidPtrType(), sizes.front(), llvm::None);
524
525 // Populate underlying ranked descriptor.
526 Type elem_ptr_ptr_type = LLVM::LLVMPointerType::get(
527 LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
528
529 Value null = rewriter.create<LLVM::NullOp>(
530 loc, LLVM::LLVMPointerType::get(llvm_elem_type, address_space));
531 UnrankedMemRefDescriptor::setAllocatedPtr(
532 rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, null);
533 UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
534 underlying_desc_ptr,
535 elem_ptr_ptr_type, null);
536 UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
537 underlying_desc_ptr, elem_ptr_ptr_type,
538 zero);
539
540 desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
541 rewriter.replaceOp(op, {desc});
542 return success();
543 }
544 };
545
546 class IsValidMemRefOpConverter
547 : public ConvertOpToLLVMPattern<IsValidMemRefOp> {
548 public:
549 using ConvertOpToLLVMPattern<IsValidMemRefOp>::ConvertOpToLLVMPattern;
550
matchAndRewrite(IsValidMemRefOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const551 LogicalResult matchAndRewrite(
552 IsValidMemRefOp op, OpAdaptor adaptor,
553 ConversionPatternRewriter &rewriter) const override {
554 Location loc = op.getLoc();
555 MemRefDescriptor desc(adaptor.arg());
556
557 // Compare every size in the descriptor to 0 to check num_elements == 0.
558 int64_t rank = op.arg().getType().cast<MemRefType>().getRank();
559 Value is_empty_shape = rewriter.create<LLVM::ConstantOp>(
560 loc, rewriter.getI1Type(), rewriter.getBoolAttr(false));
561 Value zero = createIndexConstant(rewriter, loc, 0);
562 for (int i = 0; i < rank; ++i) {
563 Value size = desc.size(rewriter, loc, i);
564 Value is_zero_size = rewriter.create<LLVM::ICmpOp>(
565 loc, rewriter.getI1Type(), LLVM::ICmpPredicate::eq, size, zero);
566 is_empty_shape =
567 rewriter.create<LLVM::OrOp>(loc, is_empty_shape, is_zero_size);
568 }
569
570 Value ptr = rewriter.create<LLVM::BitcastOp>(
571 loc, getVoidPtrType(), desc.allocatedPtr(rewriter, loc));
572 Value null = rewriter.create<LLVM::NullOp>(loc, getVoidPtrType());
573 Value is_not_nullptr = rewriter.create<LLVM::ICmpOp>(
574 loc, rewriter.getI1Type(), LLVM::ICmpPredicate::ne, ptr, null);
575
576 // Valid memref = ptr != NULL || num_elements == 0;
577 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, is_not_nullptr, is_empty_shape);
578 return success();
579 }
580 };
581
582 } // namespace
583
PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter * converter,RewritePatternSet * patterns)584 void PopulateTFFrameworkToLLVMConversionPatterns(LLVMTypeConverter *converter,
585 RewritePatternSet *patterns) {
586 // clang-format off
587 patterns->add<
588 IsValidMemRefOpConverter,
589 JITCompileFromStrOpConverter,
590 JITExecuteOpConverter,
591 NullContextOpConverter,
592 NullMemRefOpConverter,
593 ReportErrorOpConverter,
594 TFAllocOpConverter,
595 TFDeallocOpConverter>(*converter);
596 // clang-format on
597 }
598
599 } // namespace tf_framework
600 } // namespace kernel_gen
601 } // namespace mlir
602