xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/xla_framework_to_llvm_pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <stdexcept>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/iterator_range.h"
24 #include "llvm/IR/Constants.h"
25 #include "mlir/Conversion/LLVMCommon/Pattern.h"  // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
28 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
29 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
34 #include "mlir/IR/TypeRange.h"  // from @llvm-project
35 #include "mlir/Support/LLVM.h"  // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
41 
42 namespace mlir {
43 namespace mhlo {
44 namespace {
45 
46 // Create a memref descriptor given a pointer and memref type information.
47 struct XLABufferToMemOpConversion
48     : public ConvertOpToLLVMPattern<::mlir::xla_framework::XLABufferToMemOp> {
49   using ConvertOpToLLVMPattern<
50       ::mlir::xla_framework::XLABufferToMemOp>::ConvertOpToLLVMPattern;
51 
matchAndRewritemlir::mhlo::__anonf543fb420111::XLABufferToMemOpConversion52   LogicalResult matchAndRewrite(
53       ::mlir::xla_framework::XLABufferToMemOp op, OpAdaptor adaptor,
54       ConversionPatternRewriter &rewriter) const override {
55     auto loc = op.getLoc();
56     auto mem_ref_type = op.getType();
57 
58     SmallVector<Value, 4> sizes;
59     SmallVector<Value, 4> strides;
60     Value size_bytes;
61     this->getMemRefDescriptorSizes(loc, mem_ref_type, ValueRange(), rewriter,
62                                    sizes, strides, size_bytes);
63 
64     auto ptr_type = LLVM::LLVMPointerType::get(
65         typeConverter->convertType(mem_ref_type.getElementType()),
66         mem_ref_type.getMemorySpaceAsInt());
67     Value ptr =
68         rewriter.create<LLVM::BitcastOp>(loc, ptr_type, adaptor.buffer());
69 
70     Value result = this->createMemRefDescriptor(loc, mem_ref_type, ptr, ptr,
71                                                 sizes, strides, rewriter);
72     rewriter.replaceOp(op, {result});
73     return success();
74   }
75 };
76 
77 // Convert to the expected function signature and offer unwrapping for each of
78 // the original arguments.
79 struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
80   using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
81 
LoadValuemlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion82   Value LoadValue(ConversionPatternRewriter &rewriter, Location loc,
83                   Value pointer, Value index) const {
84     return rewriter.create<LLVM::LoadOp>(
85         loc, rewriter.create<LLVM::GEPOp>(
86                  loc,
87                  LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
88                      IntegerType::get(rewriter.getContext(), 8))),
89                  pointer, index));
90   }
91 
convertFuncOpToLLVMFuncOpmlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion92   mlir::func::FuncOp convertFuncOpToLLVMFuncOp(
93       func::FuncOp funcOp, ConversionPatternRewriter &rewriter) const {
94     auto loc = funcOp.getLoc();
95 
96     // This signature is predetermined by
97     // tensorflow/compiler/xla/service/cpu/ir_function.cc
98     //
99     // This only works for the global function version that tf.compile uses.
100     // Local functions will only be called by MLIR compiled code, so we can
101     // ignore them.
102     SmallVector<Type, 6> arg_types;
103     arg_types.reserve(6);
104     arg_types.push_back(
105         LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)));
106     arg_types.push_back(
107         LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)));
108     arg_types.push_back(LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
109         IntegerType::get(rewriter.getContext(), 8))));
110     arg_types.push_back(LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
111         IntegerType::get(rewriter.getContext(), 8))));
112     arg_types.push_back(LLVM::LLVMPointerType::get(
113         IntegerType::get(rewriter.getContext(), 64)));
114     arg_types.push_back(LLVM::LLVMPointerType::get(
115         IntegerType::get(rewriter.getContext(), 64)));
116     auto llvm_type =
117         mlir::FunctionType::get(rewriter.getContext(), arg_types, {});
118 
119     if (!llvm_type) return nullptr;
120 
121     rewriter.setInsertionPoint(funcOp);
122     auto new_func_op = rewriter.create<mlir::func::FuncOp>(
123         loc, funcOp.getName(), llvm_type, llvm::SmallVector<NamedAttribute>());
124     auto locs = llvm::SmallVector<mlir::Location>(arg_types.size(), loc);
125     Block *new_entry =
126         rewriter.createBlock(&new_func_op.getBody(), {}, arg_types, locs);
127 
128     // This assertion might change but is in place for the current
129     // implementation.
130     assert(funcOp.getFunctionType().getNumResults() == 0 &&
131            "xla_entry function lowered with result values when memrefs should "
132            "be caller supplied");
133 
134     BlockAndValueMapping mapping;
135     auto num_refs = funcOp.getFunctionType().getNumInputs();
136     auto result_index = 0;
137     for (unsigned i = 0; i < num_refs; ++i) {
138       if (funcOp.getArgAttr(i, "xla_framework.input_mapping")) {
139         Value index = rewriter.create<LLVM::ConstantOp>(
140             loc, typeConverter->convertType(rewriter.getIntegerType(32)),
141             funcOp.getArgAttrOfType<mlir::IntegerAttr>(
142                 i, "xla_framework.input_mapping"));
143 
144         Value ptr = LoadValue(rewriter, loc, new_entry->getArgument(3), index);
145         mapping.map(funcOp.front().getArgument(i), ptr);
146       } else {
147         Value index = rewriter.create<LLVM::ConstantOp>(
148             loc, typeConverter->convertType(rewriter.getIntegerType(32)),
149             funcOp->getAttrOfType<mlir::IntegerAttr>(
150                 "xla_framework.result_mapping"));
151         Value first_load =
152             LoadValue(rewriter, loc, new_entry->getArgument(3), index);
153 
154         // Handle multi-value results which are wrapped in a tuple.
155         if (funcOp->hasAttr("xla_framework.result_inner_mapping")) {
156           auto current_index = result_index++;
157           Value inner_index = rewriter.create<LLVM::ConstantOp>(
158               loc, typeConverter->convertType(rewriter.getIntegerType(32)),
159               rewriter.getI32IntegerAttr(static_cast<int32_t>(
160                   funcOp
161                       ->getAttrOfType<mlir::ArrayAttr>(
162                           "xla_framework.result_inner_mapping")
163                       .getValue()[current_index]
164                       .cast<mlir::IntegerAttr>()
165                       .getValue()
166                       .getSExtValue())));
167 
168           auto ptr =
169               LoadValue(rewriter, loc, new_entry->getArgument(3), inner_index);
170           mapping.map(funcOp.front().getArgument(i), ptr);
171 
172           auto ptr_type = LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
173               IntegerType::get(rewriter.getContext(), 8)));
174           first_load =
175               rewriter.create<LLVM::BitcastOp>(loc, ptr_type, first_load);
176 
177           Value second_index = rewriter.create<LLVM::ConstantOp>(
178               loc, typeConverter->convertType(rewriter.getIntegerType(32)),
179               rewriter.getI32IntegerAttr(current_index));
180           rewriter.create<LLVM::StoreOp>(
181               loc, ptr,
182               rewriter.create<LLVM::GEPOp>(loc, ptr_type, first_load,
183                                            llvm::makeArrayRef(second_index)));
184 
185         } else {
186           // Non tuple outputs can be simply mapped to the first load op.
187           mapping.map(funcOp.front().getArgument(i), first_load);
188         }
189       }
190     }
191 
192     // Clone the region and handle ReturnOps specially as there will be no
193     // return values now.
194     for (auto &op : funcOp.front()) {
195       if (isa<mlir::func::ReturnOp>(op)) {
196         rewriter.create<mlir::func::ReturnOp>(loc, ValueRange());
197       } else {
198         rewriter.clone(op, mapping);
199       }
200     }
201 
202     return new_func_op;
203   }
204 
matchAndRewritemlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion205   LogicalResult matchAndRewrite(
206       func::FuncOp funcOp, OpAdaptor,
207       ConversionPatternRewriter &rewriter) const override {
208     // Only outline functions that are globally available.
209     if (!funcOp->hasAttr("xla_entry")) return failure();
210 
211     // Store the type of memref-typed arguments before the conversion so that we
212     // can promote them to MemRef descriptor at the beginning of the function.
213     convertFuncOpToLLVMFuncOp(funcOp, rewriter);
214 
215     rewriter.eraseOp(funcOp);
216     return success();
217   }
218 };
219 
220 class LegalizeXLAFrameworkToLLVMPass
221     : public LegalizeXLAFrameworkToLLVMBase<LegalizeXLAFrameworkToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const222   void getDependentDialects(DialectRegistry &registry) const override {
223     registry.insert<func::FuncDialect, LLVM::LLVMDialect,
224                     xla_framework::XLAFrameworkDialect>();
225   }
226 
227  public:
LegalizeXLAFrameworkToLLVMPass()228   explicit LegalizeXLAFrameworkToLLVMPass() {}
229 
runOnOperation()230   void runOnOperation() override {
231     ModuleOp m = getOperation();
232 
233     // Populate type conversions.
234     MLIRContext *ctx = m.getContext();
235     LLVMTypeConverter type_converter(ctx);
236     type_converter.addConversion([&](::mlir::xla_framework::BufferType) {
237       return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
238     });
239 
240     // Populate patterns.
241     RewritePatternSet patterns(&getContext());
242     patterns.add<XLABufferToMemOpConversion, BarePtrFuncOpConversion>(
243         type_converter, 2);
244     //  Set target.
245     ConversionTarget target(*ctx);
246     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
247     target.addIllegalDialect<xla_framework::XLAFrameworkDialect>();
248     target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
249       if (llvm::any_of(
250               llvm::concat<const Type>(op.getArgumentTypes(),
251                                        op.getResultTypes()),
252               [](Type type) { return type.isa<xla_framework::BufferType>(); }))
253         return false;
254       return true;
255     });
256 
257     if (failed(applyFullConversion(m, target, std::move(patterns)))) {
258       signalPassFailure();
259     }
260   }
261 };
262 
263 }  // namespace
264 
265 std::unique_ptr<OperationPass<ModuleOp>>
CreateLegalizeXLAFrameworkToLLVMPass()266 CreateLegalizeXLAFrameworkToLLVMPass() {
267   return std::make_unique<LegalizeXLAFrameworkToLLVMPass>();
268 }
269 
270 }  // namespace mhlo
271 }  // namespace mlir
272