/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"

#define DEBUG_TYPE "tf-layout-optimization"

namespace mlir {
namespace TF {

namespace {

// Helper method that returns an op from 'transpose_ops' that matches the
// criteria for the given 'operand' and 'permutation'.
TransposeOp ReuseExistingTranspose(const OpOperand* operand,
                                   const SmallVector<int64_t, 4>& permutation,
                                   Operation* op, ConstOp permutation_op,
                                   SmallVector<TransposeOp, 2>* transpose_ops) {
  for (auto it = transpose_ops->begin(); it != transpose_ops->end(); ++it) {
    auto transpose_op = *it;
    for (auto transpose_operand : transpose_op.getOperands()) {
      auto ranked_transpose_type =
          transpose_operand.getType().dyn_cast_or_null<RankedTensorType>();
      if (!ranked_transpose_type) continue;
      if (ranked_transpose_type.getRank() == permutation.size() &&
          operand->get().getType() ==
              ShuffleRankedTensorType(ranked_transpose_type, permutation)) {
        TransposeOp transpose = transpose_op;
        transpose.getOperation()->moveBefore(op);
        transpose.setOperand(0, operand->get());
        transpose.setOperand(1, permutation_op);
        transpose_ops->erase(it);
        return transpose;
      }
    }
  }
  return nullptr;
}

// LayoutAssignmentPass assigns optimal data layout (data format) for all
// layout sensitive operations.
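//
// A sketch of the effect (assuming a device where NCHW is the optimal
// layout): a 4-D layout sensitive op such as tf.Conv2D with
// data_format = "NHWC" gets its data_format attribute updated to "NCHW", its
// layout dependent operands wrapped in tf.Transpose with permutation
// [0, 3, 1, 2], and its layout dependent results wrapped in tf.Transpose with
// the inverse permutation [0, 2, 3, 1].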
class LayoutAssignmentPass
    : public LayoutAssignmentPassBase<LayoutAssignmentPass> {
 public:
  LayoutAssignmentPass() = default;
  explicit LayoutAssignmentPass(const std::string& force_data_format) {
    force_data_format_ = force_data_format;
  }

  LayoutAssignmentPass(const LayoutAssignmentPass& pass) {}

  void runOnOperation() final;
};

// MoveTransposesPass moves all Transpose ops to the beginning or to the end of
// the basic block where they are defined. This allows the canonicalizer to
// delete redundant transposes.
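//
// Clustering transposes at one end of the block pairs each transpose with its
// inverse (for example, a transpose with permutation [0, 2, 3, 1] feeding one
// with permutation [0, 3, 1, 2]), which the transpose folding patterns can
// then cancel.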
class MoveTransposesPass : public MoveTransposesPassBase<MoveTransposesPass> {
 public:
  MoveTransposesPass() = default;
  explicit MoveTransposesPass(MoveTransposeDirection direction,
                              bool fold_transpose_in_ops) {
    this->direction_ = direction;
    this->fold_transpose_in_ops_ = fold_transpose_in_ops;
  }
  MoveTransposesPass(const MoveTransposesPass& pass) {}

  void runOnOperation() final;
};

using Permutation = SmallVector<int64_t, 4>;

void LayoutAssignmentPass::runOnOperation() {
  func::FuncOp func = getOperation();

  // Get runtime devices information from the closest parent module.
  RuntimeDevices devices;
  if (failed(::tensorflow::GetDevicesFromOp(func->getParentOfType<ModuleOp>(),
                                            &devices)))
    return signalPassFailure();

  // If there is no runtime device information and data format is not
  // explicitly forced, there is nothing to do.
  if (devices.NumDevices() == 0 && force_data_format_.empty()) return;

  func.walk([&](LayoutSensitiveInterface layout_sensitive_interface) {
    // Get desired op data format.
    StringRef target_data_format = force_data_format_;
    if (target_data_format.empty()) {
      target_data_format = layout_sensitive_interface.GetOptimalLayout(devices);
    }

    // Skip ops that already use target data format.
    auto data_format = layout_sensitive_interface.data_format();
    if (data_format == target_data_format) return;

    // Transpose arguments into the target data format.
    Permutation args_permutation =
        GetDataFormatPermutation(data_format, target_data_format);

    // Transpose results back to the original data format.
    Permutation res_permutation =
        GetDataFormatPermutation(target_data_format, data_format);
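    // For the common 4-D case this yields, e.g., [0, 3, 1, 2] for
    // "NHWC" -> "NCHW" and [0, 2, 3, 1] for the way back.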

    if (args_permutation.empty() || res_permutation.empty()) return;

    mlir::Operation* op = layout_sensitive_interface.getOperation();
    Location loc = op->getLoc();
    OpBuilder builder = OpBuilder::atBlockEnd(op->getBlock());

    auto perm_attr = [&](Permutation permutation) -> DenseIntElementsAttr {
      auto perm_ty = RankedTensorType::get({4}, builder.getIntegerType(64));
      return DenseIntElementsAttr::get(perm_ty, permutation);
    };

    // Change operation data format.
    if (failed(layout_sensitive_interface.UpdateDataFormat(target_data_format)))
      return;

    // Permute arguments into the target data format.
    builder.setInsertionPoint(op);
    auto arg_perm = builder.create<ConstOp>(loc, perm_attr(args_permutation));

    for (int64_t arg : layout_sensitive_interface.GetLayoutDependentArgs()) {
      op->setOperand(
          arg, builder.create<TransposeOp>(loc, op->getOperand(arg), arg_perm));
    }

    // Permute results back to the original data format.
    builder.setInsertionPointAfter(op);
    auto res_perm = builder.create<ConstOp>(loc, perm_attr(res_permutation));

    for (int64_t res : layout_sensitive_interface.GetLayoutDependentResults()) {
      OpResult result = op->getResult(res);

      auto transposed_res = builder.create<TransposeOp>(loc, result, res_perm);
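      // Note: replaceAllUsesWith below would also rewrite the operand of the
      // freshly created transpose, so its input is reset to `result` right
      // after.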
      result.replaceAllUsesWith(transposed_res);
      transposed_res.setOperand(0, result);
    }
  });
}

// Move Transpose operations that permute `op` results before the `op`.
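//
// A sketch for a layout agnostic op (the op and IR names are illustrative
// only):
//
//   %0 = "tf.Tanh"(%x)
//   %1 = "tf.Transpose"(%0, %perm)
//
// becomes
//
//   %x_t = "tf.Transpose"(%x, %perm)
//   %1 = "tf.Tanh"(%x_t)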
void MoveTransposeBefore(Operation* op, SmallVector<Operation*, 8>* work_list) {
  // TODO(ezhulenev): Move transpose across layout sensitive operations.
  if (!op->hasTrait<OpTrait::TF::LayoutAgnostic>()) return;

  // Transpose operations that use operation results.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for result transposes.
  ConstOp permutation_op;

  // All operation results must be used by transpose operations with the same
  // permutation indices.
  for (OpResult result : op->getResults()) {
    for (Operation* user : result.getUsers()) {
      // Result user must be a transpose operation.
      TransposeOp transpose = dyn_cast<TransposeOp>(user);
      if (!transpose) return;

      // With permutation defined by constant operation.
      ConstOp perm =
          dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
      if (!perm) return;

      // With the same permutation indices.
      auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
      if (!dense_elem_attr) return;

      if (!permutation_op) permutation_op = perm;

      // Check that permutation matches for all result transposes.
      if (perm.value() != permutation_op.value()) return;

      // Add a transpose operation for later reuse.
      transpose_ops.push_back(transpose);
    }
  }

  // Nothing to do here.
  if (!permutation_op || transpose_ops.empty()) return;
  SmallVector<int64_t, 4> permutation;
  auto perm_attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : perm_attr.getValues<APInt>())
    permutation.push_back(value.getSExtValue());

  // We want to make sure the shape of the operand equals the transposed shape.
  // A mismatch can happen if 'op' supports broadcasting and the operands have
  // different ranks.
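  // For example, if a broadcasting op adds a 4-D tensor to a scalar operand,
  // the ranks differ and the move is skipped.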
  if (op->hasTrait<OpTrait::ResultsBroadcastableShape>()) {
    auto transpose_op = *transpose_ops.begin();
    auto result_type =
        transpose_op.getResult().getType().dyn_cast_or_null<ShapedType>();
    auto is_valid_move =
        llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool {
          auto operand_type = operand.getType().dyn_cast_or_null<ShapedType>();
          return result_type && operand_type && result_type.hasRank() &&
                 operand_type.hasRank() &&
                 result_type.getRank() == operand_type.getRank();
        });
    if (!is_valid_move) return;
  }

  // At this point we checked that we can safely move Transpose node before
  // `op`, and bypass all result transposes.
  Location loc = op->getLoc();

  // Move constant op defining result permutation to the beginning of the block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for all results.
  for (OpResult result : op->getResults()) {
    result.setType(cast<TransposeOp>(*result.getUsers().begin()).y().getType());
    for (Operation* transpose : result.getUsers()) {
      transpose->getResult(0).replaceAllUsesWith(result);
    }
  }

  // Maybe add a Transpose node for all operands (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPoint(op);

  for (OpOperand& operand : op->getOpOperands()) {
    // Try to push transpose further up.
    if (Operation* operand_op = operand.get().getDefiningOp())
      work_list->push_back(operand_op);

    // Try to reuse result transposes.
    TransposeOp transpose = ReuseExistingTranspose(
        &operand, permutation, op, permutation_op, &transpose_ops);
    // If no transpose is available for reuse, create a new one.
    if (!transpose)
      transpose =
          builder.create<TransposeOp>(loc, operand.get(), permutation_op);

    operand.set(transpose);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

// Revert the permutation applied in `type`.
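// For example (a sketch): with permutation [0, 3, 1, 2] and permuted shape
// [1, 3, 112, 112], the recovered shape is [1, 112, 112, 3], since
// new_shape[permutation[i]] = shape[i].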
static mlir::ShapedType ReversePermuteShapedType(
    mlir::ShapedType type, ArrayRef<int64_t> permutation) {
  if (!type.hasRank()) return type;

  auto shape = type.getShape();
  SmallVector<int64_t, 4> new_shape(shape.size());

  for (int i = 0; i < permutation.size(); ++i) {
    int64_t index = permutation[i];
    assert(index < shape.size());
    new_shape[index] = shape[i];
  }

  return type.clone(new_shape);
}

// Move Transpose operations that permute `op` operands after the `op`.
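//
// A sketch for a layout agnostic op (the op and IR names are illustrative
// only):
//
//   %x_t = "tf.Transpose"(%x, %perm)
//   %0 = "tf.Tanh"(%x_t)
//
// becomes
//
//   %0 = "tf.Tanh"(%x)
//   %1 = "tf.Transpose"(%0, %perm)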
void MoveTransposeAfter(Operation* op, SmallVector<Operation*, 8>* work_list,
                        bool fold_transpose_in_ops) {
  // Indices of operands and results that depend on data layout.
  SmallVector<unsigned, 4> layout_dependent_operands;
  SmallVector<unsigned, 4> layout_dependent_results;

  auto fold_operands = dyn_cast<FoldOperandsTransposeInterface>(op);
  bool layout_agnostic = op->hasTrait<OpTrait::TF::LayoutAgnostic>();

  if (fold_operands && fold_transpose_in_ops) {
    layout_dependent_operands = fold_operands.GetLayoutDependentArgs();
    layout_dependent_results = fold_operands.GetLayoutDependentResults();

  } else if (layout_agnostic) {
    // For a layout agnostic operation (e.g. element-wise operations), all
    // operands and results must have the same data layout.
    for (unsigned i = 0; i < op->getNumOperands(); ++i)
      layout_dependent_operands.push_back(i);
    for (unsigned i = 0; i < op->getNumResults(); ++i)
      layout_dependent_results.push_back(i);
  }

  // Transpose operations that are operands of the `op`.
  SmallVector<TransposeOp, 2> transpose_ops;

  // Constant operation that defines permutation indices for operand transposes.
  ConstOp permutation_op;

  // Layout dependent operands must be transpose operations with the same
  // permutation indices.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);

    // Operand must be defined by a transpose op.
    TransposeOp transpose =
        dyn_cast_or_null<TransposeOp>(operand.get().getDefiningOp());
    if (!transpose) return;

    // With permutation defined by constant operation.
    ConstOp perm =
        dyn_cast_or_null<ConstOp>(transpose.getOperand(1).getDefiningOp());
    if (!perm) return;

    // With the same permutation indices.
    auto dense_elem_attr = perm.value().dyn_cast<DenseElementsAttr>();
    if (!dense_elem_attr) return;

    if (!permutation_op) permutation_op = perm;

    // Check that permutation matches for all operand transposes.
    if (perm.value() != permutation_op.value()) return;

    // Add a transpose operation for later reuse only if it's used once.
    if (transpose.getResult().hasOneUse()) transpose_ops.push_back(transpose);
  }

  // Nothing to do here.
  if (!permutation_op) return;

  // All results after transpose must preserve the original result type.
  SmallVector<Type, 4> original_type(op->getNumResults());
  for (unsigned idx : layout_dependent_results)
    original_type[idx] = op->getResult(idx).getType();

  SmallVector<int64_t, 8> permutation;

  auto attr = permutation_op.value().cast<DenseElementsAttr>();
  for (const auto& value : attr.getValues<APInt>())
    permutation.push_back(value.getSExtValue());

  // Check if we can fold transpose into the operation.
  if (fold_operands && fold_transpose_in_ops) {
    if (failed(fold_operands.FoldOperandsPermutation(permutation))) return;
  }

  // At this point we checked that we can safely move Transpose node after
  // `op`, bypass all operand transposes, and transpose op results.
  Location loc = op->getLoc();

  // Move constant op defining operand permutation to the beginning of the
  // block.
  permutation_op.getOperation()->moveBefore(&op->getBlock()->front());

  // Bypass Transpose nodes for layout dependent operands.
  for (unsigned idx : layout_dependent_operands) {
    OpOperand& operand = op->getOpOperand(idx);
    TransposeOp transpose =
        dyn_cast<TransposeOp>(operand.get().getDefiningOp());
    operand.set(transpose.getOperand(0));
  }

  // Maybe add Transpose nodes for layout dependent results
  // (or reuse existing transposes).
  OpBuilder builder(op);
  builder.setInsertionPointAfter(op);

  for (unsigned idx : layout_dependent_results) {
    OpResult result = op->getResult(idx);

    // If the op is layout agnostic, the new result type can be generated by
    // reverting `permutation`. Otherwise, operations with custom folding will
    // update the result type in `FoldOperandsPermutation`.
    if (layout_agnostic)
      result.setType(ReversePermuteShapedType(
          result.getType().cast<ShapedType>(), permutation));

    // Try to push transpose further down.
    for (Operation* user : result.getUsers()) {
      if (!llvm::isa<TransposeOp>(user)) work_list->push_back(user);
    }

    // Try to reuse operand transposes.
    TransposeOp transpose;
    if (!transpose_ops.empty()) {
      transpose = transpose_ops.pop_back_val();
      transpose.getOperation()->moveBefore(op->getNextNode());
      transpose.setOperand(0, result);
      transpose.setOperand(1, permutation_op);
      transpose.getResult().setType(original_type[idx]);
    } else {
      transpose = builder.create<TransposeOp>(loc, result, permutation_op);
    }

    // Forward all users to the transpose operation.
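    // (replaceAllUsesWith would also rewrite the operand of the transpose
    // itself, so its input is reset to `result` right after.)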
    result.replaceAllUsesWith(transpose);
    transpose.setOperand(0, result);
  }

  // Remove unused transpose operations.
  while (!transpose_ops.empty()) {
    TransposeOp transpose = transpose_ops.pop_back_val();
    transpose.erase();
  }
}

void MoveTransposesPass::runOnOperation() {
  func::FuncOp func = getOperation();

  SmallVector<Operation*, 8> work_list;
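  // The work list is seeded with operations adjacent to existing transposes;
  // MoveTransposeBefore/MoveTransposeAfter push newly exposed candidates back
  // onto it until no further moves are possible.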

  func.walk([&](TransposeOp transpose) {
    if (direction_ == MoveTransposeDirection::kBegin) {
      // Try to push transpose before the operand operation.
      for (auto operand : transpose.getOperands()) {
        if (auto op = operand.getDefiningOp()) work_list.push_back(op);
      }
    } else {
      // Try to push transpose after the user operation.
      for (Operation* user : transpose.y().getUsers()) {
        if (!llvm::isa<TransposeOp>(user)) work_list.push_back(user);
      }
    }
  });

  while (!work_list.empty()) {
    Operation* op = work_list.pop_back_val();
    if (direction_ == MoveTransposeDirection::kBegin) {
      MoveTransposeBefore(op, &work_list);
    } else if (direction_ == MoveTransposeDirection::kEnd) {
      MoveTransposeAfter(op, &work_list, fold_transpose_in_ops_);
    }
  }

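  // Finally, try to fold each remaining transpose in place so that redundant
  // ones (e.g. cancelling pairs brought together by the moves) are cleaned up.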
  func.walk([&](TransposeOp transpose) {
    OpBuilder builder(transpose);
    SmallVector<Value, 1> fold_result;
    if (succeeded(builder.tryFold(transpose.getOperation(), fold_result))) {
      assert(fold_result.size() == 1);
      transpose.replaceAllUsesWith(fold_result[0]);
    }
  });
}

}  // namespace

void CreateLayoutOptimizationPipeline(
    OpPassManager& pm,  // NOLINT - MLIR contract is pass by mutable reference.
    const LayoutOptimizationPipelineOptions& options) {
  // Assign optimal layout for layout sensitive ops.
  pm.addPass(std::make_unique<LayoutAssignmentPass>(options.force_data_format));

  // Move transposes to the beginning of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposeDirection::kBegin, !options.skip_fold_transpose_in_ops));

  // Move transposes to the end of the block and try to fold them.
  pm.addPass(std::make_unique<MoveTransposesPass>(
      MoveTransposeDirection::kEnd, !options.skip_fold_transpose_in_ops));
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateLayoutAssignmentPass() {
  // This static is something of a hack: it hooks the pipeline registration
  // into the command line and piggybacks on the TableGen generated
  // registration code.
  static mlir::PassPipelineRegistration<LayoutOptimizationPipelineOptions>
      pipeline("tf-layout-optimization",
               "Assigns optimal data layout to all layout sensitive operations "
               "and cancels redundant transpose operations.",
               CreateLayoutOptimizationPipeline);
  return std::make_unique<LayoutAssignmentPass>();
}

std::unique_ptr<OperationPass<func::FuncOp>> CreateMoveTransposesPass() {
  return std::make_unique<MoveTransposesPass>();
}

}  // namespace TF
}  // namespace mlir