1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_
18
#include <array>
#include <cstdint>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
27
28 namespace mlir {
29
30 class MLIRContext;
31
32 namespace TF {
33
// Returns the reverse (inverse) of `permutation`.
// NOTE(review): presumably result[permutation[i]] == i — confirm against the
// implementation.
SmallVector<int64_t, 4> ReversePermutation(ArrayRef<int64_t> permutation);

// Returns the permutation that maps dimensions laid out in data format `from`
// (e.g. "NHWC") to data format `to` (e.g. "NCHW"). The helpers in this header
// treat an empty result as "conversion not supported" and bail out with
// failure().
// NOTE(review): the exact set of supported format pairs is defined by the
// implementation, not visible here.
SmallVector<int64_t, 4> GetDataFormatPermutation(StringRef from, StringRef to);

// Shuffles the elements of `attr` according to `permutation`. The optional
// `inner_size` makes it possible to shuffle array attributes created from
// rank 2 tensors along the outer dimension only.
// NOTE(review): presumably each group of `inner_size` consecutive elements is
// moved as a unit — confirm against the implementation.
ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef<int64_t> permutation,
                           int inner_size = 1);

// Shuffles the dimensions of a ranked tensor type according to `permutation`.
// NOTE(review): the result for non-ranked input types is not visible here —
// check the implementation before relying on it.
Type ShuffleRankedTensorType(Type type, ArrayRef<int64_t> permutation);

// Returns whether the permutations `perm0` and `perm1` cancel each other out.
// NOTE(review): presumably "applying one after the other yields the identity"
// — confirm against the implementation.
bool AreCancellablePermutations(DenseIntElementsAttr perm0,
                                DenseIntElementsAttr perm1);
49
50 // Default implementation of `LayoutSensitiveInterface::UpdateDataFormat` for
51 // layout sensitive operations that do not have any additional layout dependent
52 // attributes besides `data_format` string.
53 template <typename Op>
UpdateDataFormat(StringRef data_format,Op * op)54 LogicalResult UpdateDataFormat(StringRef data_format, Op *op) {
55 auto perm = GetDataFormatPermutation(op->data_format(), data_format);
56 if (perm.empty()) return failure();
57
58 // Update data format attribute.
59 (*op)->setAttr("data_format", StringAttr::get(op->getContext(), data_format));
60
61 // Update types for all layout sensitive results.
62 auto layout_sensitive = cast<LayoutSensitiveInterface>(op->getOperation());
63 for (unsigned idx : layout_sensitive.GetLayoutDependentResults()) {
64 OpResult result = op->getOperation()->getResult(idx);
65 result.setType(ShuffleRankedTensorType(result.getType(), perm));
66 }
67
68 return success();
69 }
70
71 // Default implementation for folding operand transpose into the operation.
72 // See `FoldOperandsTransposeInterface::FoldOperandsPermutation`.
73 template <typename Op>
74 LogicalResult FoldOperandsPermutation(
75 ArrayRef<int64_t> permutation, Op *op,
76 ArrayRef<std::pair<StringRef, ArrayAttr>> shuffle_attrs = {}) {
77 MLIRContext *context =
78 (*op)->template getParentOfType<ModuleOp>().getContext();
79
80 // We only support NHWC <-> NCHW permutations.
81 static constexpr std::array<int64_t, 4> kNchwToNhwc = {0, 2, 3, 1};
82 static constexpr std::array<int64_t, 4> kNhwcToNchw = {0, 3, 1, 2};
83
84 // Operation data format after folding `permutation`.
85 StringRef target_data_format = [&]() -> StringRef {
86 if (op->data_format() == "NHWC" && permutation.equals(kNchwToNhwc)) {
87 return "NCHW"; // cancel NCHW->NHWC operand permutation
88 } else if (op->data_format() == "NCHW" && permutation.equals(kNhwcToNchw)) {
89 return "NHWC"; // cancel NHWC->NCHW operand permutation
90 } else {
91 return "";
92 }
93 }();
94 if (target_data_format.empty()) return failure();
95
96 // To fold operand `permutation` into the `op` we need shuffle all layout
97 // dependent attributes and types with a reverse permutation, and change
98 // operation data format to `target_data_format`.
99 //
100 // Example:
101 // %1 = SomeOp(...) {data_format = NHWC}
102 // %2 = Transpose(%1) {permutation = NHWC->NCHW}
103 // %3 = Op(%2) {data_format = NCHW}
104 //
105 // To bypass %2 we have to change data format to shuffle data format from NCHW
106 // to NHWC, which is the reverse of operand permutation (function argument).
107 auto reverse_permutation =
108 GetDataFormatPermutation(op->data_format(), target_data_format);
109 if (reverse_permutation.empty()) return failure();
110
111 (*op)->setAttr("data_format", StringAttr::get(context, target_data_format));
112
113 for (auto pair : shuffle_attrs) {
114 StringRef attr_name = pair.first;
115 ArrayAttr attr_value = pair.second;
116 (*op)->setAttr(attr_name,
117 ShuffleArrayAttr(attr_value, reverse_permutation));
118 }
119
120 auto fold = cast<FoldOperandsTransposeInterface>(op->getOperation());
121 for (unsigned idx : fold.GetLayoutDependentResults()) {
122 OpResult result = op->getOperation()->getResult(idx);
123 result.setType(
124 ShuffleRankedTensorType(result.getType(), reverse_permutation));
125 }
126
127 return success();
128 }
129
130 } // namespace TF
131 } // namespace mlir
132
133 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OPS_LAYOUT_HELPER_H_
134