1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <numeric>
17
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/OperationSupport.h" // from @llvm-project
25 #include "mlir/IR/PatternMatch.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Support/FileUtilities.h" // from @llvm-project
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
31 #include "tensorflow/core/lib/io/path.h"
32
33 namespace mlir {
34 namespace TF {
35 namespace {
36
// Special values of the key_index / value_index attributes, mirroring
// tf.lookup.TextFileIndex: WHOLE_LINE selects the entire line as the key and
// LINE_NUMBER selects the (zero-based) line number as the value.
// Internal linkage already comes from the enclosing anonymous namespace.
constexpr int kTextFileIndex_WholeLine = -2;
constexpr int kTextFileIndex_LineNumber = -1;
39
40 // InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
41 // corresponding LookupTableImportV2Op if possible.
42 class InitTextFileToImportPass
43 : public InitTextFileToImportPassBase<InitTextFileToImportPass> {
44 public:
InitTextFileToImportPass()45 InitTextFileToImportPass() {}
InitTextFileToImportPass(const InitTextFileToImportPass &)46 InitTextFileToImportPass(const InitTextFileToImportPass&) {}
InitTextFileToImportPass(std::string saved_model_dir)47 explicit InitTextFileToImportPass(std::string saved_model_dir) {
48 saved_model_dir_ = saved_model_dir;
49 }
50
51 private:
52 void runOnOperation() override;
53 };
54
55 class ConvertInitializeTableFromTextFileV2
56 : public OpRewritePattern<InitializeTableFromTextFileV2Op> {
57 public:
ConvertInitializeTableFromTextFileV2(mlir::MLIRContext * context,StringRef saved_model_dir)58 explicit ConvertInitializeTableFromTextFileV2(mlir::MLIRContext* context,
59 StringRef saved_model_dir)
60 : OpRewritePattern<InitializeTableFromTextFileV2Op>(context),
61 saved_model_dir_(saved_model_dir) {}
62
matchAndRewrite(InitializeTableFromTextFileV2Op op,PatternRewriter & rewriter) const63 LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
64 PatternRewriter& rewriter) const override {
65 // Now, this pattern matching only supports the following case, which is
66 // commonly used among inference use cases:
67 //
68 // tf.lookup.TextFileInitializer(
69 // "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
70 // tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
71 //
72 // In the above case, the delimiter will be not used since the key is just a
73 // whole line and value is a line number.
74 if (op.key_index() != kTextFileIndex_WholeLine ||
75 op.value_index() != kTextFileIndex_LineNumber) {
76 return failure();
77 }
78
79 // Try to find filename from constant op.
80 DenseStringElementsAttr filename_attr;
81 if (!matchPattern(op.filename().getDefiningOp(),
82 m_Constant(&filename_attr))) {
83 return failure();
84 }
85
86 if (filename_attr.getRawStringData().size() != 1) {
87 return failure();
88 }
89 std::string filename = filename_attr.getRawStringData()[0].str();
90
91 if (!saved_model_dir_.empty()) {
92 filename = tensorflow::io::JoinPath(
93 saved_model_dir_.str(),
94 tensorflow::io::JoinPath("assets",
95 tensorflow::io::Basename(filename)));
96 }
97
98 // Read the content of the file.
99 std::string error_message;
100 auto file = openInputFile(filename, &error_message);
101 if (!file) {
102 return op.emitOpError("failed to open vocabulary file")
103 << " (" << filename << "): " << error_message;
104 }
105
106 // Splits into lines.
107 SmallVector<StringRef, 8> lines;
108 file->getBuffer().split(lines, "\n", -1, false);
109 // The resize method is used since split operator puts tail value in the end
110 // without splitting the leftovers.
111 if (op.vocab_size() != -1) lines.resize(op.vocab_size());
112
113 // Map each line to line number, starting from zero.
114 SmallVector<int64_t, 8> line_nums;
115 line_nums.resize(lines.size());
116 std::iota(line_nums.begin(), line_nums.end(), 0);
117
118 // Create constant ops for keys an values.
119 Value key_constant_tensor = rewriter.create<arith::ConstantOp>(
120 op.getLoc(),
121 DenseStringElementsAttr::get(
122 RankedTensorType::get(static_cast<int64_t>(lines.size()),
123 StringType::get(rewriter.getContext())),
124 lines));
125
126 Value value_constant_tensor = rewriter.create<arith::ConstantOp>(
127 op.getLoc(), rewriter.getI64TensorAttr(line_nums));
128
129 // Replace the given op with LookupTableImportV2Op.
130 rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
131 key_constant_tensor,
132 value_constant_tensor);
133 rewriter.eraseOp(op);
134 return success();
135 }
136
137 private:
138 StringRef saved_model_dir_;
139 };
140
runOnOperation()141 void InitTextFileToImportPass::runOnOperation() {
142 RewritePatternSet patterns(&getContext());
143 MLIRContext* context = &getContext();
144 func::FuncOp func = getOperation();
145
146 patterns.add<ConvertInitializeTableFromTextFileV2>(
147 context, StringRef(saved_model_dir_));
148 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
149 }
150
151 } // namespace
152
153 // Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
CreateInitTextFileToImportPass(std::string saved_model_dir)154 std::unique_ptr<OperationPass<func::FuncOp>> CreateInitTextFileToImportPass(
155 std::string saved_model_dir) {
156 return std::make_unique<InitTextFileToImportPass>(saved_model_dir);
157 }
158
159 } // namespace TF
160 } // namespace mlir
161