/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for some optimizations to reduce size on export.

#include <complex>
#include <cstdint>
#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#define DEBUG_TYPE "xla-prepare-for-export"

namespace mlir {
namespace mhlo {
namespace {

// Prepare module for export to XLA HLO.
struct PrepareForExportPass
    : public PrepareForExportPassBase<PrepareForExportPass> {
  void runOnOperation() override;
};

}  // end namespace

// Materializes some splat constants before export because a scalar constant
// plus a broadcast can be more efficient than a large dense literal in the
// resulting HLOInstruction.
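// Illustrative sketch (hypothetical IR, not verbatim output): a large splat
// such as
//   %0 = mhlo.constant dense<1.0> : tensor<1024xf32>
// is rewritten to roughly
//   %cst = mhlo.constant dense<1.0> : tensor<f32>
//   %0 = "mhlo.broadcast_in_dim"(%cst)
//          {broadcast_dimensions = dense<> : tensor<0xi64>}
//          : (tensor<f32>) -> tensor<1024xf32>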
void prepareConstantOp(Operation *op, SplatElementsAttr attr) {
  // Arbitrarily chosen "small" number. This could be chosen based on the
  // proto size too.
  if (attr.getNumElements() < 32) return;
  ShapedType return_type = op->getResultTypes().front().cast<ShapedType>();
  ImplicitLocOpBuilder b(op->getLoc(), op);
  ConstantOp cst;
  if (auto complexTy = return_type.getElementType().dyn_cast<ComplexType>()) {
    auto tensorType = RankedTensorType::get({}, return_type.getElementType());
    assert(complexTy.getElementType().isa<FloatType>() &&
           "unexpected int complex in MHLO");
    auto complexVal = attr.getSplatValue<std::complex<APFloat>>();
    cst = b.create<ConstantOp>(DenseElementsAttr::get(tensorType, complexVal));
  } else {
    cst = b.create<ConstantOp>(attr.getSplatValue<Attribute>());
  }
  auto broadcast =
      b.create<BroadcastInDimOp>(return_type, cst, b.getI64TensorAttr({}));
  op->replaceAllUsesWith(broadcast);
  op->erase();
}

// Ensure that there aren't any implicit captures before exporting.
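// Illustrative sketch (hypothetical IR, schematic syntax): a while op that
// implicitly captures %c from the enclosing region:
//   %w = mhlo.while(%arg) cond { ... uses %c ... } body { ... uses %c ... }
// is rewritten so the capture becomes an explicit operand, a block argument
// of both regions, and an extra (possibly unused) result:
//   %w:2 = mhlo.while(%arg, %c)
//            cond { ^bb0(%a, %c0): ... uses %c0 ... }
//            body { ^bb0(%a, %c0): ... mhlo.return ..., %c0 }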
void prepareWhileOp(WhileOp while_op) {
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(while_op->getRegions(), implicit_inputs);
  // Each captured value has to be passed as an operand to the while op; it
  // then becomes a block argument of the condition region and the body
  // region, and an extra operand to the return op in the body. It also
  // becomes an extra result of the while operation, even if it is unused.
  // We'll process the captured values one at a time and patch the body and
  // condition regions as we go, but we'll accumulate the new operands and
  // result types and recreate a new while op to replace the existing one at
  // the end.
  SmallVector<Type> returned_types(while_op->getResultTypes().begin(),
                                   while_op->getResultTypes().end());
  SmallVector<Value> operands(while_op->getOperands().begin(),
                              while_op->getOperands().end());
  Region &cond_region = while_op.cond();
  Region &body_region = while_op.body();

  for (Value input : implicit_inputs) {
    returned_types.push_back(input.getType());
    operands.push_back(input);

    Value cond_arg =
        cond_region.front().addArgument(input.getType(), input.getLoc());
    Value body_arg =
        body_region.front().addArgument(input.getType(), input.getLoc());
    for (OpOperand &operand : llvm::make_early_inc_range(input.getUses())) {
      if (cond_region.isAncestor(operand.getOwner()->getParentRegion()))
        operand.set(cond_arg);
      else if (body_region.isAncestor(operand.getOwner()->getParentRegion()))
        operand.set(body_arg);
    }
    auto return_op = cast<mhlo::ReturnOp>(body_region.front().back());
    return_op->insertOperands(return_op->getNumOperands(), body_arg);
  }
  OpBuilder builder(while_op);
  auto new_while_op = builder.create<mhlo::WhileOp>(while_op.getLoc(),
                                                    returned_types, operands);
  new_while_op.cond().getBlocks().clear();
  new_while_op.cond().takeBody(while_op.cond());
  new_while_op.body().getBlocks().clear();
  new_while_op.body().takeBody(while_op.body());
  for (auto zipped_results :
       llvm::zip_first(while_op.getResults(), new_while_op.getResults()))
    std::get<0>(zipped_results).replaceAllUsesWith(std::get<1>(zipped_results));
  while_op->erase();
}

void prepareBroadcastInDim(BroadcastInDimOp bcast) {
  DenseIntElementsAttr dims = bcast.broadcast_dimensions();
  // If the dimensions aren't sorted, there is a transpose fused into the op,
  // which the XLA builder does not support; we unfuse it here.
  if (llvm::is_sorted(dims.getValues<int64_t>())) return;

  // We need to compute a permutation that sorts the dimensions before the
  // broadcast.
  // If the dims are [2, 4, 1], we create an array of indices [0, 1, 2] and
  // sort it using the values of the first array to produce [2, 0, 1], which
  // gives us the permutation for the transpose.
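  // Illustrative sketch (hypothetical IR): broadcast_dimensions = [2, 4, 1]
  // becomes a transpose of the operand with permutation [2, 0, 1], followed
  // by the same broadcast with broadcast_dimensions = [1, 2, 4].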
  SmallVector<int64_t> transposedDim =
      to_vector(llvm::seq<int64_t>(0, dims.size()));
  auto rawDims = dims.getValues<int64_t>();
  llvm::sort(transposedDim, [&](int64_t lhs, int64_t rhs) {
    return rawDims[lhs] < rawDims[rhs];
  });
  OpBuilder builder(bcast);
  bcast.setOperand(builder.create<TransposeOp>(
      bcast.getLoc(), bcast.operand(),
      DenseIntElementsAttr::get(dims.getType(), transposedDim)));
  // Now reuse the original broadcast_dimensions and sort it.
  transposedDim.assign(rawDims.begin(), rawDims.end());
  llvm::sort(transposedDim);
  bcast.broadcast_dimensionsAttr(
      DenseIntElementsAttr::get(dims.getType(), transposedDim));
}

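// Note: the `return`s in the walk callback below only exit the per-operation
// lambda early (all branches return void); they do not abort the walk, which
// still visits every operation.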
void PrepareForExportPass::runOnOperation() {
  getOperation().walk([&](Operation *op) {
    mlir::SplatElementsAttr attr;
    if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);

    if (auto whileOp = dyn_cast<WhileOp>(op)) return prepareWhileOp(whileOp);
    if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
      return prepareBroadcastInDim(bcastOp);
  });
}

std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareForExport() {
  return std::make_unique<PrepareForExportPass>();
}
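
// Typical usage (illustrative sketch, assuming a standard mlir::PassManager
// setup; `pm` and `context` are hypothetical locals):
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::mhlo::CreatePrepareForExport());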

}  // end namespace mhlo
}  // end namespace mlir