xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/outline_with_xla_framework.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <stdexcept>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinDialect.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/TypeRange.h"  // from @llvm-project
29 #include "mlir/Support/LLVM.h"  // from @llvm-project
30 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
31 #include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
32 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
33 #include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
34 
35 namespace mlir {
36 namespace mhlo {
37 namespace {
38 
39 // Given a FuncOp with only memref args/outputs, create a new function that
40 // wraps/unwraps xla_framework.buffer types and then calls the original
41 // function.
42 //
43 // For example:
44 //   func @func_to_outline(%arg0: memref<?xf32>) -> memref<?xf32>
45 //
46 // Will generate:
47 //   func @func_to_outline_xla_framework(%arg0: !xla_framework.buffer)
48 //     -> !xla_framework.buffer attributes {xla_entry = true} {
49 //    %0 = xla_framework.buffer_to_mem %arg0 : memref<?xf32>
50 //    %1 = call @func_to_outline(%0) : (memref<?xf32>) -> memref<?xf32>
51 //    %2 = xla_framework.mem_to_buffer %1 : memref<?xf32>
52 //    return %2 : !xla_framework.buffer
53 //   }
54 struct OutlineXLAFunc : public RewritePattern {
OutlineXLAFuncmlir::mhlo::__anon9287d38a0111::OutlineXLAFunc55   explicit OutlineXLAFunc(MLIRContext *context, PatternBenefit benefit = 1)
56       : RewritePattern(func::FuncOp::getOperationName(), benefit, context) {}
57 
filterFuncAttributesmlir::mhlo::__anon9287d38a0111::OutlineXLAFunc58   static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
59                                    bool argAttrs,
60                                    SmallVectorImpl<NamedAttribute> &result) {
61     for (const auto &attr : attrs) {
62       if (attr.getName() == SymbolTable::getSymbolAttrName() ||
63           attr.getName() == FunctionOpInterface::getTypeAttrName() ||
64           attr.getName() == "std.varargs" ||
65           (argAttrs && attr.getName() == func::FuncOp::getArgDictAttrName()))
66         continue;
67       result.push_back(attr);
68     }
69   }
70 
matchAndRewritemlir::mhlo::__anon9287d38a0111::OutlineXLAFunc71   LogicalResult matchAndRewrite(Operation *op,
72                                 PatternRewriter &rewriter) const override {
73     auto func = dyn_cast<func::FuncOp>(op);
74     auto ctx = rewriter.getContext();
75     auto loc = func.getLoc();
76     SmallVector<Location> locs(func.getFunctionType().getNumInputs(), loc);
77 
78     // Functions should only be outlined once and should only use memrefs
79     if (!func) return failure();
80     if (llvm::any_of(op->getOperandTypes(),
81                      [](Type t) { return !t.isa<MemRefType>(); }) ||
82         op->getNumResults() != 0)
83       return failure();
84     if (func->hasAttr("outlined")) return failure();
85     func->setAttr("outlined", BoolAttr::get(ctx, true));
86 
87     // Prepare new func attribute information
88     func.setSymNameAttr(mlir::StringAttr::get(ctx, func.getName()));
89     SmallVector<Type> operands(func.getFunctionType().getNumInputs(),
90                                ::mlir::xla_framework::BufferType::get(ctx));
91     SmallVector<Type> result_array(func.getFunctionType().getNumResults(),
92                                    ::mlir::xla_framework::BufferType::get(ctx));
93     auto func_type = FunctionType::get(ctx, operands, result_array);
94     SmallVector<NamedAttribute> attrs;
95     filterFuncAttributes(func->getAttrs(), true, attrs);
96     SmallVector<DictionaryAttr> arg_attrs;
97     func.getAllArgAttrs(arg_attrs);
98 
99     // The wrapper function will have the same name but with _xla_framework
100     // appended and will be annotated with the attribute "xla_entry".
101     auto outline_func = rewriter.create<func::FuncOp>(
102         loc, func.getSymName().str() + "_xla_framework", func_type, attrs,
103         arg_attrs);
104     outline_func->setAttr("outlined", BoolAttr::get(ctx, true));
105     outline_func->setAttr("xla_entry", BoolAttr::get(ctx, true));
106     auto *b = rewriter.createBlock(&outline_func.getBody(), {},
107                                    func_type.getInputs(), locs);
108 
109     // Unwrap arguments
110     SmallVector<Value> args;
111     for (const auto &t : llvm::enumerate(func.getFunctionType().getInputs())) {
112       args.push_back(rewriter.create<xla_framework::XLABufferToMemOp>(
113           loc, t.value(), b->getArgument(t.index())));
114     }
115 
116     auto call = rewriter.create<func::CallOp>(
117         loc, func.getSymName(), func.getFunctionType().getResults(), args);
118     // Wrap results
119     SmallVector<Value> results;
120     for (auto t : call.getResults()) {
121       results.push_back(rewriter.create<xla_framework::MemToXLABufferOp>(
122           loc, ::mlir::xla_framework::BufferType::get(ctx), t));
123     }
124 
125     rewriter.create<func::ReturnOp>(loc, results);
126 
127     // Finally, mark the called function as private to prevent users from
128     // accidentally trying to use it.
129     func.setVisibility(SymbolTable::Visibility::Private);
130 
131     return success();
132   }
133 };
134 
135 class OutlineWithXLAFrameworkPass
136     : public OutlineWithXLAFrameworkBase<OutlineWithXLAFrameworkPass> {
getDependentDialects(DialectRegistry & registry) const137   void getDependentDialects(DialectRegistry &registry) const override {
138     registry.insert<xla_framework::XLAFrameworkDialect, mlir::BuiltinDialect>();
139   }
140 
141  public:
OutlineWithXLAFrameworkPass()142   explicit OutlineWithXLAFrameworkPass() {}
143 
runOnOperation()144   void runOnOperation() override {
145     ModuleOp m = getOperation();
146 
147     // Populate type conversions.
148     MLIRContext *ctx = m.getContext();
149 
150     // Populate patterns.
151     RewritePatternSet patterns(&getContext());
152     patterns.add<OutlineXLAFunc>(ctx);
153     //  Set target.
154 
155     if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) {
156       signalPassFailure();
157     }
158     m->walk([](func::FuncOp f) {
159       if (f->hasAttr("outlined")) f->removeAttr("outlined");
160     });
161   }
162 };
163 
164 }  // namespace
165 
CreateOutlineWithXLAFrameworkPass()166 std::unique_ptr<OperationPass<ModuleOp> > CreateOutlineWithXLAFrameworkPass() {
167   return std::make_unique<OutlineWithXLAFrameworkPass>();
168 }
169 
170 }  // namespace mhlo
171 }  // namespace mlir
172