xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <utility>
17 
18 #include "llvm/ADT/None.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/Attributes.h"  // from @llvm-project
23 #include "mlir/IR/Builders.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
35 
36 namespace mlir {
37 namespace TFL {
38 namespace {
39 #define GEN_PASS_CLASSES
40 #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc"
41 
// This file implements the hash table legalization pass, which is responsible
// for:
// - Converting static hash table ops to the equivalent TFLite ops.
//
// The following cases need to fall back to Flex instead:
// - Mutable hash table cases
// - Other resource operators consuming a hash table resource tensor
48 
49 class LegalizeHashTableOpPattern : public OpRewritePattern<TF::HashTableV2Op> {
50  public:
51   using OpRewritePattern<TF::HashTableV2Op>::OpRewritePattern;
52 
matchAndRewrite(TF::HashTableV2Op hashtable_op,PatternRewriter & rewriter) const53   LogicalResult matchAndRewrite(TF::HashTableV2Op hashtable_op,
54                                 PatternRewriter& rewriter) const override {
55     auto output_type = RankedTensorType::get(
56         {1}, TF::ResourceType::get(rewriter.getContext()));
57 
58     // Hash the shared name to generate integer hash table id. The TFLite
59     // native resource design is based on integer keys to identify the
60     // corresponding resource objects.
61     auto table_id =
62         static_cast<int32_t>(::llvm::hash_value(hashtable_op.shared_name()));
63     auto key_dtype = hashtable_op.key_dtype();
64     auto value_dtype = hashtable_op.value_dtype();
65 
66     rewriter.replaceOpWithNewOp<TFL::HashtableOp>(
67         hashtable_op, output_type, table_id, key_dtype, value_dtype);
68     return success();
69   }
70 };
71 
72 class LegalizeHashTableFindOpPattern
73     : public OpRewritePattern<TF::LookupTableFindV2Op> {
74  public:
75   using OpRewritePattern<TF::LookupTableFindV2Op>::OpRewritePattern;
76 
matchAndRewrite(TF::LookupTableFindV2Op find_op,PatternRewriter & rewriter) const77   LogicalResult matchAndRewrite(TF::LookupTableFindV2Op find_op,
78                                 PatternRewriter& rewriter) const override {
79     auto handle_op = find_op.table_handle().getDefiningOp();
80     if (handle_op == nullptr) return failure();
81     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
82     if (hashtable_op == nullptr) return failure();
83     rewriter.replaceOpWithNewOp<TFL::HashtableFindOp>(
84         find_op, find_op->getResultTypes(), find_op.table_handle(),
85         find_op.keys(), find_op.default_value());
86     return success();
87   }
88 };
89 
90 class LegalizeHashTableImportOpPattern
91     : public OpRewritePattern<TF::LookupTableImportV2Op> {
92  public:
93   using OpRewritePattern<TF::LookupTableImportV2Op>::OpRewritePattern;
94 
matchAndRewrite(TF::LookupTableImportV2Op import_op,PatternRewriter & rewriter) const95   LogicalResult matchAndRewrite(TF::LookupTableImportV2Op import_op,
96                                 PatternRewriter& rewriter) const override {
97     auto handle_op = import_op.table_handle().getDefiningOp();
98     if (handle_op == nullptr) return failure();
99     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
100     if (hashtable_op == nullptr) return failure();
101     rewriter.replaceOpWithNewOp<TFL::HashtableImportOp>(
102         import_op, import_op->getResultTypes(), import_op.table_handle(),
103         import_op.keys(), import_op.values());
104     return success();
105   }
106 };
107 
108 class LegalizeHashTableSizeOpPattern
109     : public OpRewritePattern<TF::LookupTableSizeV2Op> {
110  public:
111   using OpRewritePattern<TF::LookupTableSizeV2Op>::OpRewritePattern;
112 
matchAndRewrite(TF::LookupTableSizeV2Op size_op,PatternRewriter & rewriter) const113   LogicalResult matchAndRewrite(TF::LookupTableSizeV2Op size_op,
114                                 PatternRewriter& rewriter) const override {
115     auto handle_op = size_op.table_handle().getDefiningOp();
116     if (handle_op == nullptr) return failure();
117     auto hashtable_op = llvm::dyn_cast<TFL::HashtableOp>(handle_op);
118     if (hashtable_op == nullptr) return failure();
119     rewriter.replaceOpWithNewOp<TFL::HashtableSizeOp>(
120         size_op, size_op->getResultTypes(), size_op.table_handle());
121     return success();
122   }
123 };
124 
125 template <typename T>
GetAllOps(mlir::ModuleOp * module)126 std::vector<T> GetAllOps(mlir::ModuleOp* module) {
127   std::vector<T> ops;
128   module->walk([&](T op) { ops.emplace_back(op); });
129   return ops;
130 }
131 
checkWhetherGraphHasValidStaticLookupTables(ModuleOp module)132 bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) {
133   auto hashtables = GetAllOps<TF::HashTableV2Op>(&module);
134   // No needs to run the legalization patterns.
135   if (hashtables.empty()) {
136     return false;
137   }
138 
139   for (auto hashtable : hashtables) {
140     auto key_dtype = hashtable.key_dtype();
141     auto value_dtype = hashtable.value_dtype();
142 
143     // Only allow string -> int64 and int64 -> string mappings due to kernel
144     // capability.
145     if (!((key_dtype.isa<TF::StringType>() && value_dtype.isa<IntegerType>() &&
146            value_dtype.cast<IntegerType>().getWidth() == 64) ||
147           (value_dtype.isa<TF::StringType>() && key_dtype.isa<IntegerType>() &&
148            key_dtype.cast<IntegerType>().getWidth() == 64))) {
149       return false;
150     }
151 
152     for (auto& use : hashtable->getUses()) {
153       Operation* user = use.getOwner();
154 
155       // Allow consuming hash table ops that can be covered by TensorFlow Lite
156       // hash table kernels.
157       if (auto find_op = llvm::dyn_cast<TF::LookupTableFindV2Op>(user))
158         continue;
159       if (auto import_op = llvm::dyn_cast<TF::LookupTableImportV2Op>(user))
160         continue;
161       if (auto size_op = llvm::dyn_cast<TF::LookupTableSizeV2Op>(user))
162         continue;
163 
164       return false;
165     }
166   }
167   return true;
168 }
169 
170 // Pass which legalizes TF hash tables only when they are covered by the
171 // TensorFlow Lite hash table kernels.
172 class LegalizeHashTablesPass
173     : public LegalizeHashTablesPassBase<LegalizeHashTablesPass> {
174  public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeHashTablesPass)175   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LegalizeHashTablesPass)
176 
177   void runOnOperation() override {
178     auto module = getOperation();
179 
180     if (!checkWhetherGraphHasValidStaticLookupTables(module)) {
181       return;
182     }
183 
184     RewritePatternSet patterns(&getContext());
185     patterns
186         .add<LegalizeHashTableOpPattern, LegalizeHashTableFindOpPattern,
187              LegalizeHashTableImportOpPattern, LegalizeHashTableSizeOpPattern>(
188             &getContext());
189     if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
190       signalPassFailure();
191       return;
192     }
193   }
194 };
195 
196 }  // namespace
197 
CreateLegalizeHashTablesPass()198 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeHashTablesPass() {
199   return std::make_unique<LegalizeHashTablesPass>();
200 }
201 
202 }  // namespace TFL
203 }  // namespace mlir
204