1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <stdexcept>
18 #include <utility>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/iterator_range.h"
24 #include "llvm/IR/Constants.h"
25 #include "mlir/Conversion/LLVMCommon/Pattern.h" // from @llvm-project
26 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
27 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
28 #include "mlir/Dialect/LLVMIR/LLVMTypes.h" // from @llvm-project
29 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
30 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/OperationSupport.h" // from @llvm-project
34 #include "mlir/IR/TypeRange.h" // from @llvm-project
35 #include "mlir/Support/LLVM.h" // from @llvm-project
36 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
37 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
38 #include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
39 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
40 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
41
42 namespace mlir {
43 namespace mhlo {
44 namespace {
45
46 // Create a memref descriptor given a pointer and memref type information.
47 struct XLABufferToMemOpConversion
48 : public ConvertOpToLLVMPattern<::mlir::xla_framework::XLABufferToMemOp> {
49 using ConvertOpToLLVMPattern<
50 ::mlir::xla_framework::XLABufferToMemOp>::ConvertOpToLLVMPattern;
51
matchAndRewritemlir::mhlo::__anonf543fb420111::XLABufferToMemOpConversion52 LogicalResult matchAndRewrite(
53 ::mlir::xla_framework::XLABufferToMemOp op, OpAdaptor adaptor,
54 ConversionPatternRewriter &rewriter) const override {
55 auto loc = op.getLoc();
56 auto mem_ref_type = op.getType();
57
58 SmallVector<Value, 4> sizes;
59 SmallVector<Value, 4> strides;
60 Value size_bytes;
61 this->getMemRefDescriptorSizes(loc, mem_ref_type, ValueRange(), rewriter,
62 sizes, strides, size_bytes);
63
64 auto ptr_type = LLVM::LLVMPointerType::get(
65 typeConverter->convertType(mem_ref_type.getElementType()),
66 mem_ref_type.getMemorySpaceAsInt());
67 Value ptr =
68 rewriter.create<LLVM::BitcastOp>(loc, ptr_type, adaptor.buffer());
69
70 Value result = this->createMemRefDescriptor(loc, mem_ref_type, ptr, ptr,
71 sizes, strides, rewriter);
72 rewriter.replaceOp(op, {result});
73 return success();
74 }
75 };
76
77 // Convert to the expected function signature and offer unwrapping for each of
78 // the original arguments.
79 struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
80 using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
81
LoadValuemlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion82 Value LoadValue(ConversionPatternRewriter &rewriter, Location loc,
83 Value pointer, Value index) const {
84 return rewriter.create<LLVM::LoadOp>(
85 loc, rewriter.create<LLVM::GEPOp>(
86 loc,
87 LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
88 IntegerType::get(rewriter.getContext(), 8))),
89 pointer, index));
90 }
91
convertFuncOpToLLVMFuncOpmlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion92 mlir::func::FuncOp convertFuncOpToLLVMFuncOp(
93 func::FuncOp funcOp, ConversionPatternRewriter &rewriter) const {
94 auto loc = funcOp.getLoc();
95
96 // This signature is predetermined by
97 // tensorflow/compiler/xla/service/cpu/ir_function.cc
98 //
99 // This only works for the global function version that tf.compile uses.
100 // Local functions will only be called by MLIR compiled code, so we can
101 // ignore them.
102 SmallVector<Type, 6> arg_types;
103 arg_types.reserve(6);
104 arg_types.push_back(
105 LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)));
106 arg_types.push_back(
107 LLVM::LLVMPointerType::get(IntegerType::get(rewriter.getContext(), 8)));
108 arg_types.push_back(LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
109 IntegerType::get(rewriter.getContext(), 8))));
110 arg_types.push_back(LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
111 IntegerType::get(rewriter.getContext(), 8))));
112 arg_types.push_back(LLVM::LLVMPointerType::get(
113 IntegerType::get(rewriter.getContext(), 64)));
114 arg_types.push_back(LLVM::LLVMPointerType::get(
115 IntegerType::get(rewriter.getContext(), 64)));
116 auto llvm_type =
117 mlir::FunctionType::get(rewriter.getContext(), arg_types, {});
118
119 if (!llvm_type) return nullptr;
120
121 rewriter.setInsertionPoint(funcOp);
122 auto new_func_op = rewriter.create<mlir::func::FuncOp>(
123 loc, funcOp.getName(), llvm_type, llvm::SmallVector<NamedAttribute>());
124 auto locs = llvm::SmallVector<mlir::Location>(arg_types.size(), loc);
125 Block *new_entry =
126 rewriter.createBlock(&new_func_op.getBody(), {}, arg_types, locs);
127
128 // This assertion might change but is in place for the current
129 // implementation.
130 assert(funcOp.getFunctionType().getNumResults() == 0 &&
131 "xla_entry function lowered with result values when memrefs should "
132 "be caller supplied");
133
134 BlockAndValueMapping mapping;
135 auto num_refs = funcOp.getFunctionType().getNumInputs();
136 auto result_index = 0;
137 for (unsigned i = 0; i < num_refs; ++i) {
138 if (funcOp.getArgAttr(i, "xla_framework.input_mapping")) {
139 Value index = rewriter.create<LLVM::ConstantOp>(
140 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
141 funcOp.getArgAttrOfType<mlir::IntegerAttr>(
142 i, "xla_framework.input_mapping"));
143
144 Value ptr = LoadValue(rewriter, loc, new_entry->getArgument(3), index);
145 mapping.map(funcOp.front().getArgument(i), ptr);
146 } else {
147 Value index = rewriter.create<LLVM::ConstantOp>(
148 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
149 funcOp->getAttrOfType<mlir::IntegerAttr>(
150 "xla_framework.result_mapping"));
151 Value first_load =
152 LoadValue(rewriter, loc, new_entry->getArgument(3), index);
153
154 // Handle multi-value results which are wrapped in a tuple.
155 if (funcOp->hasAttr("xla_framework.result_inner_mapping")) {
156 auto current_index = result_index++;
157 Value inner_index = rewriter.create<LLVM::ConstantOp>(
158 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
159 rewriter.getI32IntegerAttr(static_cast<int32_t>(
160 funcOp
161 ->getAttrOfType<mlir::ArrayAttr>(
162 "xla_framework.result_inner_mapping")
163 .getValue()[current_index]
164 .cast<mlir::IntegerAttr>()
165 .getValue()
166 .getSExtValue())));
167
168 auto ptr =
169 LoadValue(rewriter, loc, new_entry->getArgument(3), inner_index);
170 mapping.map(funcOp.front().getArgument(i), ptr);
171
172 auto ptr_type = LLVM::LLVMPointerType::get(LLVM::LLVMPointerType::get(
173 IntegerType::get(rewriter.getContext(), 8)));
174 first_load =
175 rewriter.create<LLVM::BitcastOp>(loc, ptr_type, first_load);
176
177 Value second_index = rewriter.create<LLVM::ConstantOp>(
178 loc, typeConverter->convertType(rewriter.getIntegerType(32)),
179 rewriter.getI32IntegerAttr(current_index));
180 rewriter.create<LLVM::StoreOp>(
181 loc, ptr,
182 rewriter.create<LLVM::GEPOp>(loc, ptr_type, first_load,
183 llvm::makeArrayRef(second_index)));
184
185 } else {
186 // Non tuple outputs can be simply mapped to the first load op.
187 mapping.map(funcOp.front().getArgument(i), first_load);
188 }
189 }
190 }
191
192 // Clone the region and handle ReturnOps specially as there will be no
193 // return values now.
194 for (auto &op : funcOp.front()) {
195 if (isa<mlir::func::ReturnOp>(op)) {
196 rewriter.create<mlir::func::ReturnOp>(loc, ValueRange());
197 } else {
198 rewriter.clone(op, mapping);
199 }
200 }
201
202 return new_func_op;
203 }
204
matchAndRewritemlir::mhlo::__anonf543fb420111::BarePtrFuncOpConversion205 LogicalResult matchAndRewrite(
206 func::FuncOp funcOp, OpAdaptor,
207 ConversionPatternRewriter &rewriter) const override {
208 // Only outline functions that are globally available.
209 if (!funcOp->hasAttr("xla_entry")) return failure();
210
211 // Store the type of memref-typed arguments before the conversion so that we
212 // can promote them to MemRef descriptor at the beginning of the function.
213 convertFuncOpToLLVMFuncOp(funcOp, rewriter);
214
215 rewriter.eraseOp(funcOp);
216 return success();
217 }
218 };
219
220 class LegalizeXLAFrameworkToLLVMPass
221 : public LegalizeXLAFrameworkToLLVMBase<LegalizeXLAFrameworkToLLVMPass> {
getDependentDialects(DialectRegistry & registry) const222 void getDependentDialects(DialectRegistry ®istry) const override {
223 registry.insert<func::FuncDialect, LLVM::LLVMDialect,
224 xla_framework::XLAFrameworkDialect>();
225 }
226
227 public:
LegalizeXLAFrameworkToLLVMPass()228 explicit LegalizeXLAFrameworkToLLVMPass() {}
229
runOnOperation()230 void runOnOperation() override {
231 ModuleOp m = getOperation();
232
233 // Populate type conversions.
234 MLIRContext *ctx = m.getContext();
235 LLVMTypeConverter type_converter(ctx);
236 type_converter.addConversion([&](::mlir::xla_framework::BufferType) {
237 return LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
238 });
239
240 // Populate patterns.
241 RewritePatternSet patterns(&getContext());
242 patterns.add<XLABufferToMemOpConversion, BarePtrFuncOpConversion>(
243 type_converter, 2);
244 // Set target.
245 ConversionTarget target(*ctx);
246 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
247 target.addIllegalDialect<xla_framework::XLAFrameworkDialect>();
248 target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) {
249 if (llvm::any_of(
250 llvm::concat<const Type>(op.getArgumentTypes(),
251 op.getResultTypes()),
252 [](Type type) { return type.isa<xla_framework::BufferType>(); }))
253 return false;
254 return true;
255 });
256
257 if (failed(applyFullConversion(m, target, std::move(patterns)))) {
258 signalPassFailure();
259 }
260 }
261 };
262
263 } // namespace
264
265 std::unique_ptr<OperationPass<ModuleOp>>
CreateLegalizeXLAFrameworkToLLVMPass()266 CreateLegalizeXLAFrameworkToLLVMPass() {
267 return std::make_unique<LegalizeXLAFrameworkToLLVMPass>();
268 }
269
270 } // namespace mhlo
271 } // namespace mlir
272