/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {
// Helper method that returns an op from 'transpose_ops' that matches the
// criteria for the given 'operand' and 'permutation'.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
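// For example, convolutions on NVIDIA GPUs typically prefer NCHW, so a Conv2D
// written in NHWC gets its layout dependent operands wrapped in tf.Transpose
// ops to NCHW and its results transposed back to NHWC; the concrete choice
// comes from each op's GetOptimalLayout for the available devices.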
class LayoutAssignmentPass
    : public LayoutAssignmentPassBase<LayoutAssignmentPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  void runOnOperation() final;
};

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This will allow the canonicalizer to
// delete redundant transposes.
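// For example, after layout assignment a value is often transposed back to
// NHWC by a producer and immediately transposed to NCHW again by a consumer;
// once these transposes are moved next to each other, the canonicalizer can
// cancel the inverse pair.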
class MoveTransposesPass : public MoveTransposesPassBase<MoveTransposesPass> {
 public:
  MoveTransposesPass() = default;
  explicit MoveTransposesPass(MoveTransposeDirection direction,
                              bool fold_transpose_in_ops) {
    this->direction_ = direction;
    this->fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  void runOnOperation() final;
};

using Permutation = SmallVector<int64_t, 4>;

void LayoutAssignmentPass::runOnOperation() {
  func::FuncOp func = getOperation();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);
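    // For example, with data_format == "NHWC" and target_data_format ==
    // "NCHW", args_permutation is [0, 3, 1, 2] and res_permutation is
    // [0, 2, 3, 1].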

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
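      // replaceAllUsesWith below also rewrites the operand of the newly
      // created `transposed_res`, so its operand is reset to `result` right
      // after.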
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}

// Move Transpose operations that permute `op` results before the `op`.
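// Schematically, for a layout agnostic op this rewrites
//   %r = "tf.SomeLayoutAgnosticOp"(%a, %b)
//   %t = "tf.Transpose"(%r, %perm)
// into
//   %ta = "tf.Transpose"(%a, %perm)
//   %tb = "tf.Transpose"(%b, %perm)
//   %t  = "tf.SomeLayoutAgnosticOp"(%ta, %tb)
// reusing the existing result transposes for operands where possible
// ("tf.SomeLayoutAgnosticOp" is a placeholder name).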
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that the permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getValues<APInt>())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if 'op' supports broadcasting and the operands have
  // different ranks.
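  // For example, when a rank 4 NHWC tensor is broadcast against a rank 1 bias
  // vector, the 4-D permutation cannot simply be applied to the rank 1
  // operand, so the move is only performed when every operand has the same
  // rank as the transposed result.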
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // before `op` and bypass all result transposes.
  Location loc = op->getLoc();

  // Move the constant op defining the result permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
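    // Every user of this result is a TransposeOp with the same permutation
    // (verified above), so the new result type can be taken from any of them.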
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push the transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

// Reverts the permutation previously applied to `type`.
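// For example, with permutation [0, 3, 1, 2] (NHWC -> NCHW) a result type of
// tensor<1x3x32x32xf32> is reverted back to tensor<1x32x32x3xf32>.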
static mlir::ShapedType ReversePermuteShapedType(
    mlir::ShapedType type, ArrayRef<int64_t> permutation) {
  if (!type.hasRank()) return type;

  auto shape = type.getShape();
  SmallVector<int64_t, 4> new_shape(shape.size());

  for (int i = 0; i < permutation.size(); ++i) {
    int64_t index = permutation[i];
    assert(index < shape.size());
    new_shape[index] = shape[i];
  }

  return type.clone(new_shape);
}

// Move Transpose operations that permute `op` operands after the `op`.
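// Schematically, for a layout agnostic op this rewrites
//   %ta = "tf.Transpose"(%a, %perm)
//   %tb = "tf.Transpose"(%b, %perm)
//   %r  = "tf.SomeLayoutAgnosticOp"(%ta, %tb)
// into
//   %r = "tf.SomeLayoutAgnosticOp"(%a, %b)
//   %t = "tf.Transpose"(%r, %perm)
// while ops implementing FoldOperandsTransposeInterface may instead fold the
// operand permutation directly into the operation
// ("tf.SomeLayoutAgnosticOp" is a placeholder name).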
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For layout agnostic operations (e.g. element-wise operations) all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that the permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  SmallVector<int64_t, 8> permutation;

  auto attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : attr.getValues<APInt>())
    permutation.push_back(value.getSExtValue());

  // Check if we can fold the transpose into the operation using the
  // permutation extracted above.
  if (fold_operands && fold_transpose_in_ops) {
    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we have checked that we can safely move the Transpose node
  // after `op`, bypass all operand transposes, and transpose op results.
  Location loc = op->getLoc();

  // Move the constant op defining the result permutation to the beginning of
  // the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // If the op is layout agnostic, the new result type can be generated by
    // reverting `permutation`. Otherwise, operations with custom folding will
    // update the result type in `FoldOperandsPermutation`.
    if (layout_agnostic)
      result.setType(ReversePermuteShapedType(
          result.getType().cast<ShapedType>(), permutation));

    // Try to push the transpose further down.
    for (Operation* user : result.getUsers()) {
      if (!llvm::isa<TransposeOp>(user)) work_list->push_back(user);
    }

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation, then restore the
    // transpose's own operand back to `result`.
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

void MoveTransposesPass::runOnOperation() {
  func::FuncOp func = getOperation();

  SmallVector<Operation*, 8> work_list;

  func.walk([&](TransposeOp transpose) {
    if (direction_ == MoveTransposeDirection::kBegin) {
      // Try to push the transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push the transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        if (!llvm::isa<TransposeOp>(user)) work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == MoveTransposeDirection::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == MoveTransposeDirection::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }

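  // Fold away transposes that the moves above made redundant, e.g. transposes
  // whose permutation composes to the identity, relying on TransposeOp's
  // folder via OpBuilder::tryFold below.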
  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposeDirection::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposeDirection::kEnd, !options.skip_fold_transpose_in_ops));
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateLayoutAssignmentPass() {
  // This static is kind of a hack: it hooks the pipeline registration into the
  // command line and piggy-backs on the TableGen-generated registration code.
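  // Once registered, the pipeline can be invoked from tools such as tf-opt,
  // roughly: tf-opt -tf-layout-optimization='force-data-format=NCHW' in.mlir
  // (the exact option spelling is defined by
  // LayoutOptimizationPipelineOptions).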
  static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
      pipeline("tf-layout-optimization",
               "Assigns optimal data layout to all layout sensitive operations "
               "and cancels redundant transpose operations.",
               CreateLayoutOptimizationPipeline);
  return std::make_unique<LayoutAssignmentPass>();
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateMoveTransposesPass() {
  return std::make_unique<MoveTransposesPass>();
}

}  // namespace TF
}  // namespace mlir