1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <utility>
17
18 #include "llvm/ADT/None.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
22 #include "mlir/IR/Attributes.h" // from @llvm-project
23 #include "mlir/IR/Builders.h" // from @llvm-project
24 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
25 #include "mlir/IR/MLIRContext.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "mlir/IR/PatternMatch.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
35
36 namespace mlir {
37 namespace TFL {
38 namespace {
39 #define GEN_PASS_CLASSES
40 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
41
// This file contains the "legalize hash tables" pass, which is responsible
// for:
//   - Converting static hash table ops to the equivalent TFLite ops.
//
// The following cases need to fall back to Flex instead:
//   - Mutable hash tables.
//   - Other resource operators consuming a hash table resource tensor.
48
49 class LegalizeHashTableOpPattern : public OpRewritePattern<TF::HashTableV2Op> {
50 public:
51 using OpRewritePattern<TF::HashTableV2Op>::OpRewritePattern;
52
matchAndRewrite(TF::HashTableV2Op hashtable_op,PatternRewriter & rewriter) const53 LogicalResult matchAndRewrite(TF::HashTableV2Op hashtable_op,
54 PatternRewriter& rewriter) const override {
55 auto output_type = RankedTensorType::get(
56 {1}, TF::ResourceType::get(rewriter.getContext()));
57
58 // Hash the shared name to generate integer hash table id. The TFLite
59 // native resource design is based on integer keys to identify the
60 // corresponding resource objects.
61 auto table_id =
62 static_cast<int32_t>(::llvm::hash_value(hashtable_op.shared_name()));
63 auto key_dtype = hashtable_op.key_dtype();
64 auto value_dtype = hashtable_op.value_dtype();
65
66 rewriter.replaceOpWithNewOp<TFL::HashtableOp>(
67 hashtable_op, output_type, table_id, key_dtype, value_dtype);
68 return success();
69 }
70 };
71
72 class LegalizeHashTableFindOpPattern
73 : public OpRewritePattern<TF::LookupTableFindV2Op> {
74 public:
75 using OpRewritePattern<TF::LookupTableFindV2Op>::OpRewritePattern;
76
matchAndRewrite(TF::LookupTableFindV2Op find_op,PatternRewriter & rewriter) const77 LogicalResult matchAndRewrite(TF::LookupTableFindV2Op find_op,
78 PatternRewriter& rewriter) const override {
79 auto handle_op = find_op.table_handle().getDefiningOp();
80 if (handle_op == nullptr) return failure();
81 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
82 if (hashtable_op == nullptr) return failure();
83 rewriter.replaceOpWithNewOp<TFL::HashtableFindOp>(
84 find_op, find_op->getResultTypes(), find_op.table_handle(),
85 find_op.keys(), find_op.default_value());
86 return success();
87 }
88 };
89
90 class LegalizeHashTableImportOpPattern
91 : public OpRewritePattern<TF::LookupTableImportV2Op> {
92 public:
93 using OpRewritePattern<TF::LookupTableImportV2Op>::OpRewritePattern;
94
matchAndRewrite(TF::LookupTableImportV2Op import_op,PatternRewriter & rewriter) const95 LogicalResult matchAndRewrite(TF::LookupTableImportV2Op import_op,
96 PatternRewriter& rewriter) const override {
97 auto handle_op = import_op.table_handle().getDefiningOp();
98 if (handle_op == nullptr) return failure();
99 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
100 if (hashtable_op == nullptr) return failure();
101 rewriter.replaceOpWithNewOp<TFL::HashtableImportOp>(
102 import_op, import_op->getResultTypes(), import_op.table_handle(),
103 import_op.keys(), import_op.values());
104 return success();
105 }
106 };
107
108 class LegalizeHashTableSizeOpPattern
109 : public OpRewritePattern<TF::LookupTableSizeV2Op> {
110 public:
111 using OpRewritePattern<TF::LookupTableSizeV2Op>::OpRewritePattern;
112
matchAndRewrite(TF::LookupTableSizeV2Op size_op,PatternRewriter & rewriter) const113 LogicalResult matchAndRewrite(TF::LookupTableSizeV2Op size_op,
114 PatternRewriter& rewriter) const override {
115 auto handle_op = size_op.table_handle().getDefiningOp();
116 if (handle_op == nullptr) return failure();
117 auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
118 if (hashtable_op == nullptr) return failure();
119 rewriter.replaceOpWithNewOp<TFL::HashtableSizeOp>(
120 size_op, size_op->getResultTypes(), size_op.table_handle());
121 return success();
122 }
123 };
124
125 template <typename T>
GetAllOps(mlir::ModuleOp * module)126 std::vector<T> GetAllOps(mlir::ModuleOp* module) {
127 std::vector<T> ops;
128 module->walk([&](T op) { ops.emplace_back(op); });
129 return ops;
130 }
131
checkWhetherGraphHasValidStaticLookupTables(ModuleOp module)132 bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) {
133 auto hashtables = GetAllOps<TF::HashTableV2Op>(&module);
134 // No needs to run the legalization patterns.
135 if (hashtables.empty()) {
136 return false;
137 }
138
139 for (auto hashtable : hashtables) {
140 auto key_dtype = hashtable.key_dtype();
141 auto value_dtype = hashtable.value_dtype();
142
143 // Only allow string -> int64 and int64 -> string mappings due to kernel
144 // capability.
145 if (!((key_dtype.isa<TF::StringType>() && value_dtype.isa<IntegerType>() &&
146 value_dtype.cast<IntegerType>().getWidth() == 64) ||
147 (value_dtype.isa<TF::StringType>() && key_dtype.isa<IntegerType>() &&
148 key_dtype.cast<IntegerType>().getWidth() == 64))) {
149 return false;
150 }
151
152 for (auto& use : hashtable->getUses()) {
153 Operation* user = use.getOwner();
154
155 // Allow consuming hash table ops that can be covered by TensorFlow Lite
156 // hash table kernels.
157 if (auto find_op = llvm::dyn_cast<TF::LookupTableFindV2Op>(user))
158 continue;
159 if (auto import_op = llvm::dyn_cast<TF::LookupTableImportV2Op>(user))
160 continue;
161 if (auto size_op = llvm::dyn_cast<TF::LookupTableSizeV2Op>(user))
162 continue;
163
164 return false;
165 }
166 }
167 return true;
168 }
169
170 // Pass which legalizes TF hash tables only when they are covered by the
171 // TensorFlow Lite hash table kernels.
172 class LegalizeHashTablesPass
173 : public LegalizeHashTablesPassBase<LegalizeHashTablesPass> {
174 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeHashTablesPass)175 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeHashTablesPass)
176
177 void runOnOperation() override {
178 auto module = getOperation();
179
180 if (!checkWhetherGraphHasValidStaticLookupTables(module)) {
181 return;
182 }
183
184 RewritePatternSet patterns(&getContext());
185 patterns
186 .add<LegalizeHashTableOpPattern, LegalizeHashTableFindOpPattern,
187 LegalizeHashTableImportOpPattern, LegalizeHashTableSizeOpPattern>(
188 &getContext());
189 if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
190 signalPassFailure();
191 return;
192 }
193 }
194 };
195
196 } // namespace
197
CreateLegalizeHashTablesPass()198 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHashTablesPass() {
199 return std::make_unique<LegalizeHashTablesPass>();
200 }
201
202 } // namespace TFL
203 } // namespace mlir
204