1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <stdexcept>
18 #include <utility>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
24 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
25 #include "mlir/IR/BuiltinDialect.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/TypeRange.h" // from @llvm-project
29 #include "mlir/Support/LLVM.h" // from @llvm-project
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
31 #include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
32 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
34
35 namespace mlir {
36 namespace mhlo {
37 namespace {
38
39 // Given a FuncOp with only memref args/outputs, create a new function that
40 // wraps/unwraps xla_framework.buffer types and then calls the original
41 // function.
42 //
43 // For example:
44 // func @func_to_outline(%arg0: memref<?xf32>) -> memref<?xf32>
45 //
46 // Will generate:
47 // func @func_to_outline_xla_framework(%arg0: !xla_framework.buffer)
48 // -> !xla_framework.buffer attributes {xla_entry = true} {
49 // %0 = xla_framework.buffer_to_mem %arg0 : memref<?xf32>
50 // %1 = call @func_to_outline(%0) : (memref<?xf32>) -> memref<?xf32>
51 // %2 = xla_framework.mem_to_buffer %1 : memref<?xf32>
52 // return %2 : !xla_framework.buffer
53 // }
54 struct OutlineXLAFunc : public RewritePattern {
OutlineXLAFuncmlir::mhlo::__anon9287d38a0111::OutlineXLAFunc55 explicit OutlineXLAFunc(MLIRContext *context, PatternBenefit benefit = 1)
56 : RewritePattern(func::FuncOp::getOperationName(), benefit, context) {}
57
filterFuncAttributesmlir::mhlo::__anon9287d38a0111::OutlineXLAFunc58 static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
59 bool argAttrs,
60 SmallVectorImpl<NamedAttribute> &result) {
61 for (const auto &attr : attrs) {
62 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
63 attr.getName() == FunctionOpInterface::getTypeAttrName() ||
64 attr.getName() == "std.varargs" ||
65 (argAttrs && attr.getName() == func::FuncOp::getArgDictAttrName()))
66 continue;
67 result.push_back(attr);
68 }
69 }
70
matchAndRewritemlir::mhlo::__anon9287d38a0111::OutlineXLAFunc71 LogicalResult matchAndRewrite(Operation *op,
72 PatternRewriter &rewriter) const override {
73 auto func = dyn_cast<func::FuncOp>(op);
74 auto ctx = rewriter.getContext();
75 auto loc = func.getLoc();
76 SmallVector<Location> locs(func.getFunctionType().getNumInputs(), loc);
77
78 // Functions should only be outlined once and should only use memrefs
79 if (!func) return failure();
80 if (llvm::any_of(op->getOperandTypes(),
81 [](Type t) { return !t.isa<MemRefType>(); }) ||
82 op->getNumResults() != 0)
83 return failure();
84 if (func->hasAttr("outlined")) return failure();
85 func->setAttr("outlined", BoolAttr::get(ctx, true));
86
87 // Prepare new func attribute information
88 func.setSymNameAttr(mlir::StringAttr::get(ctx, func.getName()));
89 SmallVector<Type> operands(func.getFunctionType().getNumInputs(),
90 ::mlir::xla_framework::BufferType::get(ctx));
91 SmallVector<Type> result_array(func.getFunctionType().getNumResults(),
92 ::mlir::xla_framework::BufferType::get(ctx));
93 auto func_type = FunctionType::get(ctx, operands, result_array);
94 SmallVector<NamedAttribute> attrs;
95 filterFuncAttributes(func->getAttrs(), true, attrs);
96 SmallVector<DictionaryAttr> arg_attrs;
97 func.getAllArgAttrs(arg_attrs);
98
99 // The wrapper function will have the same name but with _xla_framework
100 // appended and will be annotated with the attribute "xla_entry".
101 auto outline_func = rewriter.create<func::FuncOp>(
102 loc, func.getSymName().str() + "_xla_framework", func_type, attrs,
103 arg_attrs);
104 outline_func->setAttr("outlined", BoolAttr::get(ctx, true));
105 outline_func->setAttr("xla_entry", BoolAttr::get(ctx, true));
106 auto *b = rewriter.createBlock(&outline_func.getBody(), {},
107 func_type.getInputs(), locs);
108
109 // Unwrap arguments
110 SmallVector<Value> args;
111 for (const auto &t : llvm::enumerate(func.getFunctionType().getInputs())) {
112 args.push_back(rewriter.create<xla_framework::XLABufferToMemOp>(
113 loc, t.value(), b->getArgument(t.index())));
114 }
115
116 auto call = rewriter.create<func::CallOp>(
117 loc, func.getSymName(), func.getFunctionType().getResults(), args);
118 // Wrap results
119 SmallVector<Value> results;
120 for (auto t : call.getResults()) {
121 results.push_back(rewriter.create<xla_framework::MemToXLABufferOp>(
122 loc, ::mlir::xla_framework::BufferType::get(ctx), t));
123 }
124
125 rewriter.create<func::ReturnOp>(loc, results);
126
127 // Finally, mark the called function as private to prevent users from
128 // accidentally trying to use it.
129 func.setVisibility(SymbolTable::Visibility::Private);
130
131 return success();
132 }
133 };
134
135 class OutlineWithXLAFrameworkPass
136 : public OutlineWithXLAFrameworkBase<OutlineWithXLAFrameworkPass> {
getDependentDialects(DialectRegistry & registry) const137 void getDependentDialects(DialectRegistry ®istry) const override {
138 registry.insert<xla_framework::XLAFrameworkDialect, mlir::BuiltinDialect>();
139 }
140
141 public:
OutlineWithXLAFrameworkPass()142 explicit OutlineWithXLAFrameworkPass() {}
143
runOnOperation()144 void runOnOperation() override {
145 ModuleOp m = getOperation();
146
147 // Populate type conversions.
148 MLIRContext *ctx = m.getContext();
149
150 // Populate patterns.
151 RewritePatternSet patterns(&getContext());
152 patterns.add<OutlineXLAFunc>(ctx);
153 // Set target.
154
155 if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
156 signalPassFailure();
157 }
158 m->walk([](func::FuncOp f) {
159 if (f->hasAttr("outlined")) f->removeAttr("outlined");
160 });
161 }
162 };
163
164 } // namespace
165
CreateOutlineWithXLAFrameworkPass()166 std::unique_ptr<OperationPass<ModuleOp> > CreateOutlineWithXLAFrameworkPass() {
167 return std::make_unique<OutlineWithXLAFrameworkPass>();
168 }
169
170 } // namespace mhlo
171 } // namespace mlir
172