1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Affine/LoopUtils.h"  // from @llvm-project
#include "tensorflow/core/platform/logging.h"
25
26 namespace xla {
27 namespace experimental {
28
29 using mlir::OpBuilder;
30
GetBoundAffineMapFrom(mlir::Operation * op)31 BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
32 if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
33 return {load.getAffineMap(),
34 std::vector<mlir::Value>(load.getMapOperands().begin(),
35 load.getMapOperands().end())};
36 } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
37 return {store.getAffineMap(),
38 std::vector<mlir::Value>(store.getMapOperands().begin(),
39 store.getMapOperands().end())};
40 } else {
41 CHECK(false);
42 }
43 }
44
CloneWithNewAffineMap(mlir::Operation * op,BoundAffineMap new_affine,OpBuilder builder)45 mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
46 BoundAffineMap new_affine,
47 OpBuilder builder) {
48 if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
49 return builder.create<mlir::AffineLoadOp>(
50 builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
51 new_affine.operands);
52 } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
53 return builder.create<mlir::AffineStoreOp>(
54 builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
55 new_affine.affine_map, new_affine.operands);
56 } else {
57 CHECK(false);
58 }
59 }
60
IsSimpleLoop(mlir::AffineForOp loop)61 bool IsSimpleLoop(mlir::AffineForOp loop) {
62 return loop.getLowerBoundMap().isSingleConstant() &&
63 loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
64 loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
65 std::next(loop.getRegion().begin()) == loop.getRegion().end();
66 }
67
CreateNestedSimpleLoops(absl::Span<const int64_t> upper_bounds,OpBuilder builder)68 std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
69 absl::Span<const int64_t> upper_bounds, OpBuilder builder) {
70 std::vector<mlir::AffineForOp> loops;
71 loops.reserve(upper_bounds.size());
72 for (int64_t dim : upper_bounds) {
73 auto loop =
74 builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
75 loops.push_back(loop);
76 builder = OpBuilder::atBlockTerminator(loop.getBody());
77 }
78 return loops;
79 }
80
SetBoundForSimpleLoop(mlir::AffineForOp loop,mlir::AffineExpr new_bound,OpBuilder builder)81 void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
82 OpBuilder builder) {
83 CHECK(IsSimpleLoop(loop));
84
85 loop.setUpperBoundMap(mlir::AffineMap::get(
86 loop.getUpperBoundMap().getNumDims(),
87 loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
88 }
89
TileLoop(mlir::AffineForOp loop,int64_t size,mlir::AffineForOp target)90 mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
91 mlir::AffineForOp target) {
92 CHECK(IsSimpleLoop(loop));
93 CHECK(IsSimpleLoop(target));
94 {
95 llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
96 getPerfectlyNestedLoops(all_loops, loop);
97 CHECK(absl::c_linear_search(all_loops, target));
98 }
99
100 auto builder = OpBuilder::atBlockTerminator(target.getBody());
101
102 auto inner_loop =
103 builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
104 {
105 auto& inner_operations = inner_loop.getBody()->getOperations();
106 auto& target_operations = target.getBody()->getOperations();
107
108 inner_operations.splice(inner_operations.begin(), target_operations,
109 target_operations.begin(),
110 std::prev(target_operations.end(), 2));
111
112 mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
113 CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
114 SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
115 }
116
117 for (auto& use :
118 llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
119 mlir::Operation* owner = use.getOwner();
120 BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
121 unsigned new_dim = affine_map.operands.size();
122 affine_map.operands.push_back(inner_loop.getInductionVar());
123 std::vector<mlir::AffineExpr> replacements;
124 for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
125 if (affine_map.operands[i] == loop.getInductionVar()) {
126 replacements.push_back(builder.getAffineDimExpr(i) * size +
127 builder.getAffineDimExpr(new_dim));
128 } else {
129 replacements.push_back(builder.getAffineDimExpr(i));
130 }
131 }
132 affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
133 replacements, {}, affine_map.operands.size(), 0);
134 auto new_op = CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
135 owner->replaceAllUsesWith(new_op);
136 owner->erase();
137 }
138 return inner_loop;
139 }
140
SinkPerfectlyNestedLoops(llvm::MutableArrayRef<mlir::AffineForOp> loops,int rotate_amount)141 void SinkPerfectlyNestedLoops(llvm::MutableArrayRef<mlir::AffineForOp> loops,
142 int rotate_amount) {
143 CHECK_GE(rotate_amount, 0);
144 std::vector<unsigned> permutation(loops.size());
145 std::iota(permutation.begin(), permutation.end(), unsigned(0));
146 std::rotate(permutation.begin(),
147 permutation.begin() + loops.size() - rotate_amount,
148 permutation.end());
149 mlir::permuteLoops(loops, permutation);
150 }
151
152 } // namespace experimental
153 } // namespace xla
154