xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
17 
18 #include <string>
19 
20 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/None.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/IR/Attributes.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
32 #include "mlir/IR/Location.h"  // from @llvm-project
33 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
34 #include "mlir/IR/Matchers.h"  // from @llvm-project
35 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Support/LLVM.h"  // from @llvm-project
40 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43 
44 namespace mlir {
45 namespace TFL {
46 
47 namespace {
48 
// tf._implements API names identifying the TF.Text composite functions that
// this file can fuse into TFL custom ops.
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
// Attribute name used to tag a function with the composite API it implements.
constexpr char kTFImplements[] = "tf._implements";

using mlir::TF::FuncAttr;
using mlir::TF::StringType;
56 
CustomOption(OpBuilder * builder,const std::string & content)57 inline ConstBytesAttr CustomOption(OpBuilder* builder,
58                                    const std::string& content) {
59   return ConstBytesAttr::get(builder->getContext(),
60                              StringRef(content.data(), content.size()));
61 }
62 
GetInputType(func::FuncOp func,int idx)63 inline TensorType GetInputType(func::FuncOp func, int idx) {
64   return func.getFunctionType().getInput(idx).dyn_cast_or_null<TensorType>();
65 }
66 
GetResultType(func::FuncOp func,int idx)67 inline TensorType GetResultType(func::FuncOp func, int idx) {
68   return func.getFunctionType().getResult(idx).dyn_cast_or_null<TensorType>();
69 }
70 
RankEquals(const TensorType & type,int rank)71 inline bool RankEquals(const TensorType& type, int rank) {
72   return type && type.hasRank() && type.getRank() == rank;
73 }
74 
VerifyWhitespaceTokenizer(func::FuncOp func)75 LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) {
76   // In the case of input tensor with 0 rank.
77   // Whitespace tokenizer generates 1 output:
78   // * String tensor for tokens.
79   //
80   // In the case of 1-D input tensor,
81   // Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
82   // * 1st output is the value of ragged tensor;
83   // * 2nd output is the offset.
84   //
85   // In the case of batched input tesnor,
86   // Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
87   // * 1st output is the value of ragged tensor;
88   // * 2nd output is the inner offset;
89   // * 3rd output is the outer offset.
90   auto input_type = GetInputType(func, 0);
91   if (!input_type || !input_type.getElementType().isa<StringType>() ||
92       !input_type.hasRank()) {
93     return func.emitError() << "Input should be a string tensor";
94   }
95 
96   const std::vector<int> kValidNumOfOutput = {1, 2, 3};
97   if (input_type.getRank() >= kValidNumOfOutput.size()) {
98     return func.emitError()
99            << "Unrecognized input rank: " << input_type.getRank();
100   }
101   if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
102     return func.emitError()
103            << "Expect " << kValidNumOfOutput[input_type.getRank()]
104            << "output(s) when input has rank " << input_type.getRank();
105   }
106 
107   auto value_type = GetResultType(func, 0);
108   if (!RankEquals(value_type, 1) ||
109       !value_type.getElementType().isa<StringType>()) {
110     return func.emitError() << "1st output should be string tensor";
111   }
112   if (func.getNumResults() > 1) {
113     auto offset_type = GetResultType(func, 1);
114     if (!RankEquals(offset_type, 1) ||
115         !offset_type.getElementType().isInteger(64)) {
116       return func.emitError() << "2nd output should be int64 tensor";
117     }
118   }
119   if (func.getNumResults() > 2) {
120     auto offset_type = GetResultType(func, 2);
121     if (!RankEquals(offset_type, 1) ||
122         !offset_type.getElementType().isInteger(64)) {
123       return func.emitError() << "3rd output should be int64 tensor";
124     }
125   }
126 
127   return success();
128 }
129 
ConvertWhitespaceTokenizer(func::FuncOp func,llvm::StringRef api,FuncAttr attr)130 LogicalResult ConvertWhitespaceTokenizer(func::FuncOp func, llvm::StringRef api,
131                                          FuncAttr attr) {
132   func.eraseBody();
133   func.addEntryBlock();
134   func->setAttr(kTFImplements, attr);
135   OpBuilder builder(func.getBody());
136   std::string empty_option_buffer;
137   auto op = builder.create<CustomOp>(
138       func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
139       api, CustomOption(&builder, empty_option_buffer));
140   builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
141   return success();
142 }
143 
VerifyNgrams(func::FuncOp func)144 LogicalResult VerifyNgrams(func::FuncOp func) {
145   // The inputs and outputs should be the same:
146   // * A string tensor for tokens/ragged tensor values.
147   // * Zero or more row_split tensors.
148   constexpr int kValues = 0;
149   constexpr int kRowSplits = 1;
150 
151   if (func.getFunctionType().getInputs().size() !=
152       func.getFunctionType().getResults().size()) {
153     return func.emitError() << "Mismatched number of inputs and outputs.";
154   }
155 
156   int row_splits = func.getFunctionType().getInputs().size() - kRowSplits;
157   if (row_splits == 0) {
158     auto input_values = GetInputType(func, kValues);
159     if (!input_values || !input_values.getElementType().isa<StringType>()) {
160       return func.emitError()
161              << "Input " << kValues << " should be a string tensor";
162     }
163     auto output_values = GetResultType(func, kValues);
164     if (!output_values || !output_values.getElementType().isa<StringType>()) {
165       return func.emitError()
166              << "Output " << kValues << " should be a string tensor";
167     }
168 
169     if (input_values.hasRank() && output_values.hasRank() &&
170         input_values.getRank() != output_values.getRank()) {
171       return func.emitError() << "Input " << kValues << " and output "
172                               << kValues << " should have the same rank";
173     }
174   } else {
175     auto input_values = GetInputType(func, kValues);
176     if (!RankEquals(input_values, 1) ||
177         !input_values.getElementType().isa<StringType>()) {
178       return func.emitError()
179              << "Input " << kValues << " should be a 1D string tensor";
180     }
181     auto output_values = GetResultType(func, kValues);
182     if (!RankEquals(output_values, 1) ||
183         !output_values.getElementType().isa<StringType>()) {
184       return func.emitError()
185              << "Output " << kValues << " should be a 1D string tensor";
186     }
187 
188     for (int i = 0; i < row_splits; ++i) {
189       const int row_index = i + kRowSplits;
190       auto input_row_splits = GetInputType(func, row_index);
191       if (!RankEquals(input_row_splits, 1) ||
192           !input_row_splits.getElementType().isInteger(64)) {
193         return func.emitError()
194                << "Input " << row_index << " should be a 1D int64 tensor";
195       }
196       auto output_row_splits = GetResultType(func, row_index);
197       if (!RankEquals(output_row_splits, 1) ||
198           !output_row_splits.getElementType().isInteger(64)) {
199         return func.emitError()
200                << "Output " << row_index << " should be a 1D int64 tensor";
201       }
202     }
203   }
204 
205   return success();
206 }
207 
CreateNgramsCustomOption(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)208 LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs,
209                                        std::string& custom_option_buffer) {
210   flexbuffers::Builder fbb;
211   size_t start_map = fbb.StartMap();
212 
213   auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
214   if (!width) {
215     return func.emitError() << "'width' attribute is not set or not an integer";
216   }
217   fbb.Int("width", width.getInt());
218 
219   auto string_separator =
220       attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
221   if (!string_separator) {
222     return func.emitError()
223            << "'string_separator' attribute is not set or not a string";
224   }
225   // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
226   // strings expect NUL terminated strings.
227   std::string string_separator_str(string_separator.getValue().data(),
228                                    string_separator.getValue().size());
229   fbb.String("string_separator", string_separator_str);
230 
231   auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
232   if (!axis) {
233     return func.emitError() << "'axis' attribute is not set or not an integer";
234   }
235   fbb.Int("axis", axis.getInt());
236 
237   auto reduction_type =
238       attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
239   if (!reduction_type) {
240     return func.emitError()
241            << "'reduction_type' attribute is not set or not a string";
242   }
243   // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
244   // strings expect NUL terminated strings.
245   std::string reduction_type_str(reduction_type.getValue().data(),
246                                  reduction_type.getValue().size());
247   fbb.String("reduction_type", reduction_type_str);
248 
249   fbb.EndMap(start_map);
250   fbb.Finish();
251   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
252   return success();
253 }
254 
ConvertNgrams(func::FuncOp func,llvm::StringRef api,FuncAttr attr)255 LogicalResult ConvertNgrams(func::FuncOp func, llvm::StringRef api,
256                             FuncAttr attr) {
257   func.eraseBody();
258   func.addEntryBlock();
259   func->setAttr(kTFImplements, attr);
260   OpBuilder builder(func.getBody());
261   std::string custom_option_buffer;
262   if (failed(CreateNgramsCustomOption(func, attr.getAttrs(),
263                                       custom_option_buffer))) {
264     return failure();
265   }
266   auto op = builder.create<CustomOp>(
267       func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
268       api, CustomOption(&builder, custom_option_buffer));
269   builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
270   return success();
271 }
272 
VerifySgnnProjection(func::FuncOp func,FuncAttr attr)273 LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) {
274   if (func.getFunctionType().getNumInputs() != 2 ||
275       func.getFunctionType().getNumResults() != 1) {
276     return func.emitError() << "Mismatched number of inputs and outputs.";
277   }
278   auto values_type = GetInputType(func, 0);
279   if (!values_type || !values_type.getElementType().isa<StringType>()) {
280     return func.emitError() << "First input should be a string tensor";
281   }
282   auto row_splits_type = GetInputType(func, 1);
283   if (!row_splits_type ||
284       !row_splits_type.getElementType().isa<IntegerType>()) {
285     return func.emitError() << "Second input should be an integer tensor";
286   }
287 
288   auto hash_seed =
289       attr.getAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
290   if (!hash_seed) {
291     return func.emitError()
292            << "'hash_seed' attribute is not set or not an array";
293   }
294   auto output_type = GetResultType(func, 0);
295   if (!output_type || !output_type.getElementType().isa<FloatType>() ||
296       !RankEquals(output_type, 2)) {
297     return func.emitError() << "Output should be a 2D float tensor.";
298   }
299   if (output_type.getDimSize(1) != hash_seed.size()) {
300     return func.emitError()
301            << "Output 2nd dimension should be the num of hash seeds.";
302   }
303 
304   auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
305   if (!buckets) {
306     return func.emitError() << "'buckets' attribute is not set or not int";
307   }
308 
309   return success();
310 }
311 
CreateSgnnProjectionCustomOption(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)312 LogicalResult CreateSgnnProjectionCustomOption(
313     func::FuncOp func, DictionaryAttr attrs,
314     std::string& custom_option_buffer) {
315   flexbuffers::Builder fbb;
316   size_t start_map = fbb.StartMap();
317 
318   auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
319   auto vector_start = fbb.StartVector("hash_seed");
320   for (int i = 0; i < hash_seed.size(); i++) {
321     fbb.Add(static_cast<int32_t>(
322         (hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
323   }
324   fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);
325 
326   auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
327   fbb.Int("buckets", buckets.getInt());
328 
329   fbb.EndMap(start_map);
330   fbb.Finish();
331   custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
332   return success();
333 }
334 
ConvertSgnnProjection(func::FuncOp func,llvm::StringRef api,FuncAttr attr)335 LogicalResult ConvertSgnnProjection(func::FuncOp func, llvm::StringRef api,
336                                     FuncAttr attr) {
337   // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
338   func.eraseBody();
339   func.addEntryBlock();
340   func->setAttr(kTFImplements, attr);
341   OpBuilder builder(func.getBody());
342   std::string custom_option_buffer;
343   if (failed(CreateSgnnProjectionCustomOption(func, attr.getAttrs(),
344                                               custom_option_buffer))) {
345     return failure();
346   }
347   auto op = builder.create<CustomOp>(
348       func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
349       api, CustomOption(&builder, custom_option_buffer));
350   builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
351   return success();
352 }
353 }  // namespace
354 
ConvertTFTextAPI(func::FuncOp func,llvm::StringRef api,FuncAttr attr)355 LogicalResult ConvertTFTextAPI(func::FuncOp func, llvm::StringRef api,
356                                FuncAttr attr) {
357   if (api.str() == kWhitespaceTokenizer) {
358     if (succeeded(VerifyWhitespaceTokenizer(func))) {
359       return ConvertWhitespaceTokenizer(func, api, attr);
360     }
361   } else if (api.str() == kNgrams) {
362     if (succeeded(VerifyNgrams(func))) {
363       return ConvertNgrams(func, api, attr);
364     }
365   } else if (api.str() == kCustomSgnnProjection) {
366     if (succeeded(VerifySgnnProjection(func, attr))) {
367       return ConvertSgnnProjection(func, api, attr);
368     }
369   }
370   return failure();
371 }
372 
IsTFTextRegistered(const tensorflow::OpRegistry * op_registery)373 bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery) {
374   const std::vector<std::string> kTFTextOps = {
375       "WhitespaceTokenizeWithOffsets",
376   };
377   for (const auto& iter : kTFTextOps) {
378     if (op_registery->LookUp(iter)) {
379       return true;
380     }
381   }
382   return false;
383 }
384 
385 }  // namespace TFL
386 }  // namespace mlir
387