1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <numeric>
17 
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Casting.h"
20 #include "llvm/Support/MemoryBuffer.h"
21 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
25 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Support/FileUtilities.h"  // from @llvm-project
28 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
31 #include "tensorflow/core/lib/io/path.h"
32 
33 namespace mlir {
34 namespace TF {
35 namespace {
36 
// Sentinel column indices accepted by InitializeTableFromTextFileV2Op; they
// correspond to tf.lookup.TextFileIndex.WHOLE_LINE and
// tf.lookup.TextFileIndex.LINE_NUMBER respectively.
constexpr int kTextFileIndex_WholeLine = -2;   // Use the entire line.
constexpr int kTextFileIndex_LineNumber = -1;  // Use the zero-based line index.
39 
40 // InitTextFileToImportPass converts InitializeTableFromTextFileV2Op to the
41 // corresponding LookupTableImportV2Op if possible.
42 class InitTextFileToImportPass
43     : public InitTextFileToImportPassBase<InitTextFileToImportPass> {
44  public:
InitTextFileToImportPass()45   InitTextFileToImportPass() {}
InitTextFileToImportPass(const InitTextFileToImportPass &)46   InitTextFileToImportPass(const InitTextFileToImportPass&) {}
InitTextFileToImportPass(std::string saved_model_dir)47   explicit InitTextFileToImportPass(std::string saved_model_dir) {
48     saved_model_dir_ = saved_model_dir;
49   }
50 
51  private:
52   void runOnOperation() override;
53 };
54 
55 class ConvertInitializeTableFromTextFileV2
56     : public OpRewritePattern<InitializeTableFromTextFileV2Op> {
57  public:
ConvertInitializeTableFromTextFileV2(mlir::MLIRContext * context,StringRef saved_model_dir)58   explicit ConvertInitializeTableFromTextFileV2(mlir::MLIRContext* context,
59                                                 StringRef saved_model_dir)
60       : OpRewritePattern<InitializeTableFromTextFileV2Op>(context),
61         saved_model_dir_(saved_model_dir) {}
62 
matchAndRewrite(InitializeTableFromTextFileV2Op op,PatternRewriter & rewriter) const63   LogicalResult matchAndRewrite(InitializeTableFromTextFileV2Op op,
64                                 PatternRewriter& rewriter) const override {
65     // Now, this pattern matching only supports the following case, which is
66     // commonly used among inference use cases:
67     //
68     // tf.lookup.TextFileInitializer(
69     //   "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
70     //   tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" ")
71     //
72     // In the above case, the delimiter will be not used since the key is just a
73     // whole line and value is a line number.
74     if (op.key_index() != kTextFileIndex_WholeLine ||
75         op.value_index() != kTextFileIndex_LineNumber) {
76       return failure();
77     }
78 
79     // Try to find filename from constant op.
80     DenseStringElementsAttr filename_attr;
81     if (!matchPattern(op.filename().getDefiningOp(),
82                       m_Constant(&filename_attr))) {
83       return failure();
84     }
85 
86     if (filename_attr.getRawStringData().size() != 1) {
87       return failure();
88     }
89     std::string filename = filename_attr.getRawStringData()[0].str();
90 
91     if (!saved_model_dir_.empty()) {
92       filename = tensorflow::io::JoinPath(
93           saved_model_dir_.str(),
94           tensorflow::io::JoinPath("assets",
95                                    tensorflow::io::Basename(filename)));
96     }
97 
98     // Read the content of the file.
99     std::string error_message;
100     auto file = openInputFile(filename, &error_message);
101     if (!file) {
102       return op.emitOpError("failed to open vocabulary file")
103              << " (" << filename << "): " << error_message;
104     }
105 
106     // Splits into lines.
107     SmallVector<StringRef, 8> lines;
108     file->getBuffer().split(lines, "\n", -1, false);
109     // The resize method is used since split operator puts tail value in the end
110     // without splitting the leftovers.
111     if (op.vocab_size() != -1) lines.resize(op.vocab_size());
112 
113     // Map each line to line number, starting from zero.
114     SmallVector<int64_t, 8> line_nums;
115     line_nums.resize(lines.size());
116     std::iota(line_nums.begin(), line_nums.end(), 0);
117 
118     // Create constant ops for keys an values.
119     Value key_constant_tensor = rewriter.create<arith::ConstantOp>(
120         op.getLoc(),
121         DenseStringElementsAttr::get(
122             RankedTensorType::get(static_cast<int64_t>(lines.size()),
123                                   StringType::get(rewriter.getContext())),
124             lines));
125 
126     Value value_constant_tensor = rewriter.create<arith::ConstantOp>(
127         op.getLoc(), rewriter.getI64TensorAttr(line_nums));
128 
129     // Replace the given op with LookupTableImportV2Op.
130     rewriter.create<LookupTableImportV2Op>(op.getLoc(), op.table_handle(),
131                                            key_constant_tensor,
132                                            value_constant_tensor);
133     rewriter.eraseOp(op);
134     return success();
135   }
136 
137  private:
138   StringRef saved_model_dir_;
139 };
140 
runOnOperation()141 void InitTextFileToImportPass::runOnOperation() {
142   RewritePatternSet patterns(&getContext());
143   MLIRContext* context = &getContext();
144   func::FuncOp func = getOperation();
145 
146   patterns.add<ConvertInitializeTableFromTextFileV2>(
147       context, StringRef(saved_model_dir_));
148   (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
149 }
150 
151 }  // namespace
152 
153 // Replace InitializeTableFromTextFileV2Ops with LookupTableImportV2Ops.
CreateInitTextFileToImportPass(std::string saved_model_dir)154 std::unique_ptr<OperationPass<func::FuncOp>> CreateInitTextFileToImportPass(
155     std::string saved_model_dir) {
156   return std::make_unique<InitTextFileToImportPass>(saved_model_dir);
157 }
158 
159 }  // namespace TF
160 }  // namespace mlir
161