/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for some optimizations to reduce size on export.

#include <cassert>
#include <complex>
#include <cstdint>
#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

#define DEBUG_TYPE "xla-prepare-for-export"

namespace mlir {
namespace mhlo {
namespace {

// Prepare module for export to XLA HLO.
struct PrepareForExportPass
    : public PrepareForExportPassBase<PrepareForExportPass> {
  void runOnOperation() override;
};

} // end namespace

// Materializes large splat constants as a scalar constant plus a broadcast
// before export, because this may be more efficient to represent in
// HLOInstruction.
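// As an illustration (shapes here are arbitrary, not taken from this file),
// a splat constant such as
//   %0 = mhlo.constant dense<1.000000e+00> : tensor<1024x512xf32>
// is rewritten into a scalar constant followed by a broadcast:
//   %cst = mhlo.constant dense<1.000000e+00> : tensor<f32>
//   %0 = "mhlo.broadcast_in_dim"(%cst) {broadcast_dimensions = dense<>
//        : tensor<0xi64>} : (tensor<f32>) -> tensor<1024x512xf32>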
void prepareConstantOp(Operation *op, SplatElementsAttr attr) {
  // Arbitrarily chosen "small" number. This could be chosen based on the
  // proto size too.
  if (attr.getNumElements() < 32) return;
  ShapedType return_type = op->getResultTypes().front().cast<ShapedType>();
  ImplicitLocOpBuilder b(op->getLoc(), op);
  ConstantOp cst;
  if (auto complexTy = return_type.getElementType().dyn_cast<ComplexType>()) {
    auto tensorType = RankedTensorType::get({}, return_type.getElementType());
    assert(complexTy.getElementType().isa<FloatType>() &&
           "unexpected int complex in MHLO");
    auto complexVal = attr.getSplatValue<std::complex<APFloat>>();
    cst = b.create<ConstantOp>(DenseElementsAttr::get(tensorType, complexVal));
  } else {
    cst = b.create<ConstantOp>(attr.getSplatValue<Attribute>());
  }
  auto broadcast =
      b.create<BroadcastInDimOp>(return_type, cst, b.getI64TensorAttr({}));
  op->replaceAllUsesWith(broadcast);
  op->erase();
}

// Ensure that there are no implicit captures before exporting.
void prepareWhileOp(WhileOp while_op) {
  llvm::SetVector<Value> implicit_inputs;
  getUsedValuesDefinedAbove(while_op->getRegions(), implicit_inputs);
  // Each captured value has to be passed as an operand to the while, and then
  // becomes an operand to both the condition region and the body region, as
  // well as an extra operand to the return op in the body. It also becomes an
  // extra result of the while operation, even if it is unused.
  // We'll process the captured values one at a time and patch the body and
  // condition regions as we go, but we'll accumulate the new operands and
  // result types and recreate a new while op to replace the existing one at
  // the end.
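  // As an illustrative sketch (not verbatim IR): given a value %x defined
  // above the loop and used inside its body,
  //   %x = ...
  //   %w = mhlo.while(%arg) : cond { ... }, body { ... uses %x ... }
  // the rewrite threads %x through explicitly:
  //   %w:2 = mhlo.while(%arg, %x) : both regions gain an extra block
  //   argument, the body's mhlo.return forwards it, and the extra result
  //   %w#1 is simply left unused.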
  SmallVector<Type> returned_types(while_op->getResultTypes().begin(),
                                   while_op->getResultTypes().end());
  SmallVector<Value> operands(while_op->getOperands().begin(),
                              while_op->getOperands().end());
  Region &cond_region = while_op.cond();
  Region &body_region = while_op.body();

  for (Value input : implicit_inputs) {
    returned_types.push_back(input.getType());
    operands.push_back(input);

    Value cond_arg =
        cond_region.front().addArgument(input.getType(), input.getLoc());
    Value body_arg =
        body_region.front().addArgument(input.getType(), input.getLoc());
    for (OpOperand &operand : llvm::make_early_inc_range(input.getUses())) {
      if (cond_region.isAncestor(operand.getOwner()->getParentRegion()))
        operand.set(cond_arg);
      else if (body_region.isAncestor(operand.getOwner()->getParentRegion()))
        operand.set(body_arg);
    }
    auto return_op = cast<mhlo::ReturnOp>(body_region.front().back());
    return_op->insertOperands(return_op->getNumOperands(), body_arg);
  }
  OpBuilder builder(while_op);
  auto new_while_op = builder.create<mhlo::WhileOp>(while_op.getLoc(),
                                                    returned_types, operands);
  new_while_op.cond().getBlocks().clear();
  new_while_op.cond().takeBody(while_op.cond());
  new_while_op.body().getBlocks().clear();
  new_while_op.body().takeBody(while_op.body());
  for (auto zipped_results :
       llvm::zip_first(while_op.getResults(), new_while_op.getResults()))
    std::get<0>(zipped_results).replaceAllUsesWith(std::get<1>(zipped_results));
  while_op->erase();
}

void prepareBroadcastInDim(BroadcastInDimOp bcast) {
  DenseIntElementsAttr dims = bcast.broadcast_dimensions();
  // If the dimensions aren't sorted, there is a transpose fused into the op,
  // which the XLA builder does not support; we unfuse it here.
  if (llvm::is_sorted(dims.getValues<int64_t>())) return;

  // We need to compute a permutation that sorts the dimensions before the
  // broadcast.
  // If the dims are [2, 4, 1], we create an array of indices [0, 1, 2] and
  // sort it using the values of the first array to produce [2, 0, 1], which
  // gives us the operand for the transpose.
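  // The broadcast_dimensions attribute is then replaced by its sorted form,
  // so (continuing the example above) the op becomes a transpose with
  // permutation [2, 0, 1] feeding a broadcast_in_dim with
  // broadcast_dimensions = [1, 2, 4].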
  SmallVector<int64_t> transposedDim =
      to_vector(llvm::seq<int64_t>(0, dims.size()));
  auto rawDims = dims.getValues<int64_t>();
  llvm::sort(transposedDim, [&](int64_t lhs, int64_t rhs) {
    return rawDims[lhs] < rawDims[rhs];
  });
  OpBuilder builder(bcast);
  bcast.setOperand(builder.create<TransposeOp>(
      bcast.getLoc(), bcast.operand(),
      DenseIntElementsAttr::get(dims.getType(), transposedDim)));
  // Now reuse the original broadcast_dimensions and sort it.
  transposedDim.assign(rawDims.begin(), rawDims.end());
  llvm::sort(transposedDim);
  bcast.broadcast_dimensionsAttr(
      DenseIntElementsAttr::get(dims.getType(), transposedDim));
}

void PrepareForExportPass::runOnOperation() {
  getOperation().walk([&](Operation *op) {
    mlir::SplatElementsAttr attr;
    if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr);

    if (auto whileOp = dyn_cast<WhileOp>(op)) return prepareWhileOp(whileOp);
    if (auto bcastOp = dyn_cast<BroadcastInDimOp>(op))
      return prepareBroadcastInDim(bcastOp);
  });
}

std::unique_ptr<OperationPass<func::FuncOp>> CreatePrepareForExport() {
  return std::make_unique<PrepareForExportPass>();
}
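
// Illustrative usage (a hypothetical pipeline setup, not code from this
// file): the pass is typically scheduled on each function just before
// translation to HLO, e.g.
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       mlir::mhlo::CreatePrepareForExport());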

} // end namespace mhlo
} // end namespace mlir