1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/tftext_utils.h"
17
#include <array>
#include <string>
19
20 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
21 #include "llvm/ADT/ArrayRef.h"
22 #include "llvm/ADT/None.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/IR/Attributes.h" // from @llvm-project
29 #include "mlir/IR/Builders.h" // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
31 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
32 #include "mlir/IR/Location.h" // from @llvm-project
33 #include "mlir/IR/MLIRContext.h" // from @llvm-project
34 #include "mlir/IR/Matchers.h" // from @llvm-project
35 #include "mlir/IR/OpDefinition.h" // from @llvm-project
36 #include "mlir/IR/Operation.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Support/LLVM.h" // from @llvm-project
40 #include "mlir/Support/LogicalResult.h" // from @llvm-project
41 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
43
44 namespace mlir {
45 namespace TFL {
46
47 namespace {
48
// `api` names that ConvertTFTextAPI recognizes and converts.
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
// Attribute name under which each Convert* function stores the incoming
// FuncAttr on the rewritten function.
constexpr char kTFImplements[] = "tf._implements";

using mlir::TF::FuncAttr;
using mlir::TF::StringType;
56
CustomOption(OpBuilder * builder,const std::string & content)57 inline ConstBytesAttr CustomOption(OpBuilder* builder,
58 const std::string& content) {
59 return ConstBytesAttr::get(builder->getContext(),
60 StringRef(content.data(), content.size()));
61 }
62
GetInputType(func::FuncOp func,int idx)63 inline TensorType GetInputType(func::FuncOp func, int idx) {
64 return func.getFunctionType().getInput(idx).dyn_cast_or_null<TensorType>();
65 }
66
GetResultType(func::FuncOp func,int idx)67 inline TensorType GetResultType(func::FuncOp func, int idx) {
68 return func.getFunctionType().getResult(idx).dyn_cast_or_null<TensorType>();
69 }
70
RankEquals(const TensorType & type,int rank)71 inline bool RankEquals(const TensorType& type, int rank) {
72 return type && type.hasRank() && type.getRank() == rank;
73 }
74
VerifyWhitespaceTokenizer(func::FuncOp func)75 LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) {
76 // In the case of input tensor with 0 rank.
77 // Whitespace tokenizer generates 1 output:
78 // * String tensor for tokens.
79 //
80 // In the case of 1-D input tensor,
81 // Whitespace tokenizer generates 2 outputs to make up a ragged tensor:
82 // * 1st output is the value of ragged tensor;
83 // * 2nd output is the offset.
84 //
85 // In the case of batched input tesnor,
86 // Whitespace tokenizer has 3 outputs to make up a nested ragged tensor:
87 // * 1st output is the value of ragged tensor;
88 // * 2nd output is the inner offset;
89 // * 3rd output is the outer offset.
90 auto input_type = GetInputType(func, 0);
91 if (!input_type || !input_type.getElementType().isa<StringType>() ||
92 !input_type.hasRank()) {
93 return func.emitError() << "Input should be a string tensor";
94 }
95
96 const std::vector<int> kValidNumOfOutput = {1, 2, 3};
97 if (input_type.getRank() >= kValidNumOfOutput.size()) {
98 return func.emitError()
99 << "Unrecognized input rank: " << input_type.getRank();
100 }
101 if (func.getNumResults() != kValidNumOfOutput[input_type.getRank()]) {
102 return func.emitError()
103 << "Expect " << kValidNumOfOutput[input_type.getRank()]
104 << "output(s) when input has rank " << input_type.getRank();
105 }
106
107 auto value_type = GetResultType(func, 0);
108 if (!RankEquals(value_type, 1) ||
109 !value_type.getElementType().isa<StringType>()) {
110 return func.emitError() << "1st output should be string tensor";
111 }
112 if (func.getNumResults() > 1) {
113 auto offset_type = GetResultType(func, 1);
114 if (!RankEquals(offset_type, 1) ||
115 !offset_type.getElementType().isInteger(64)) {
116 return func.emitError() << "2nd output should be int64 tensor";
117 }
118 }
119 if (func.getNumResults() > 2) {
120 auto offset_type = GetResultType(func, 2);
121 if (!RankEquals(offset_type, 1) ||
122 !offset_type.getElementType().isInteger(64)) {
123 return func.emitError() << "3rd output should be int64 tensor";
124 }
125 }
126
127 return success();
128 }
129
ConvertWhitespaceTokenizer(func::FuncOp func,llvm::StringRef api,FuncAttr attr)130 LogicalResult ConvertWhitespaceTokenizer(func::FuncOp func, llvm::StringRef api,
131 FuncAttr attr) {
132 func.eraseBody();
133 func.addEntryBlock();
134 func->setAttr(kTFImplements, attr);
135 OpBuilder builder(func.getBody());
136 std::string empty_option_buffer;
137 auto op = builder.create<CustomOp>(
138 func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
139 api, CustomOption(&builder, empty_option_buffer));
140 builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
141 return success();
142 }
143
VerifyNgrams(func::FuncOp func)144 LogicalResult VerifyNgrams(func::FuncOp func) {
145 // The inputs and outputs should be the same:
146 // * A string tensor for tokens/ragged tensor values.
147 // * Zero or more row_split tensors.
148 constexpr int kValues = 0;
149 constexpr int kRowSplits = 1;
150
151 if (func.getFunctionType().getInputs().size() !=
152 func.getFunctionType().getResults().size()) {
153 return func.emitError() << "Mismatched number of inputs and outputs.";
154 }
155
156 int row_splits = func.getFunctionType().getInputs().size() - kRowSplits;
157 if (row_splits == 0) {
158 auto input_values = GetInputType(func, kValues);
159 if (!input_values || !input_values.getElementType().isa<StringType>()) {
160 return func.emitError()
161 << "Input " << kValues << " should be a string tensor";
162 }
163 auto output_values = GetResultType(func, kValues);
164 if (!output_values || !output_values.getElementType().isa<StringType>()) {
165 return func.emitError()
166 << "Output " << kValues << " should be a string tensor";
167 }
168
169 if (input_values.hasRank() && output_values.hasRank() &&
170 input_values.getRank() != output_values.getRank()) {
171 return func.emitError() << "Input " << kValues << " and output "
172 << kValues << " should have the same rank";
173 }
174 } else {
175 auto input_values = GetInputType(func, kValues);
176 if (!RankEquals(input_values, 1) ||
177 !input_values.getElementType().isa<StringType>()) {
178 return func.emitError()
179 << "Input " << kValues << " should be a 1D string tensor";
180 }
181 auto output_values = GetResultType(func, kValues);
182 if (!RankEquals(output_values, 1) ||
183 !output_values.getElementType().isa<StringType>()) {
184 return func.emitError()
185 << "Output " << kValues << " should be a 1D string tensor";
186 }
187
188 for (int i = 0; i < row_splits; ++i) {
189 const int row_index = i + kRowSplits;
190 auto input_row_splits = GetInputType(func, row_index);
191 if (!RankEquals(input_row_splits, 1) ||
192 !input_row_splits.getElementType().isInteger(64)) {
193 return func.emitError()
194 << "Input " << row_index << " should be a 1D int64 tensor";
195 }
196 auto output_row_splits = GetResultType(func, row_index);
197 if (!RankEquals(output_row_splits, 1) ||
198 !output_row_splits.getElementType().isInteger(64)) {
199 return func.emitError()
200 << "Output " << row_index << " should be a 1D int64 tensor";
201 }
202 }
203 }
204
205 return success();
206 }
207
CreateNgramsCustomOption(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)208 LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs,
209 std::string& custom_option_buffer) {
210 flexbuffers::Builder fbb;
211 size_t start_map = fbb.StartMap();
212
213 auto width = attrs.get("width").dyn_cast_or_null<IntegerAttr>();
214 if (!width) {
215 return func.emitError() << "'width' attribute is not set or not an integer";
216 }
217 fbb.Int("width", width.getInt());
218
219 auto string_separator =
220 attrs.get("string_separator").dyn_cast_or_null<StringAttr>();
221 if (!string_separator) {
222 return func.emitError()
223 << "'string_separator' attribute is not set or not a string";
224 }
225 // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
226 // strings expect NUL terminated strings.
227 std::string string_separator_str(string_separator.getValue().data(),
228 string_separator.getValue().size());
229 fbb.String("string_separator", string_separator_str);
230
231 auto axis = attrs.get("axis").dyn_cast_or_null<IntegerAttr>();
232 if (!axis) {
233 return func.emitError() << "'axis' attribute is not set or not an integer";
234 }
235 fbb.Int("axis", axis.getInt());
236
237 auto reduction_type =
238 attrs.get("reduction_type").dyn_cast_or_null<StringAttr>();
239 if (!reduction_type) {
240 return func.emitError()
241 << "'reduction_type' attribute is not set or not a string";
242 }
243 // StringAttrs are not guaranteed to be NUL terminated, but flexbuffers
244 // strings expect NUL terminated strings.
245 std::string reduction_type_str(reduction_type.getValue().data(),
246 reduction_type.getValue().size());
247 fbb.String("reduction_type", reduction_type_str);
248
249 fbb.EndMap(start_map);
250 fbb.Finish();
251 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
252 return success();
253 }
254
ConvertNgrams(func::FuncOp func,llvm::StringRef api,FuncAttr attr)255 LogicalResult ConvertNgrams(func::FuncOp func, llvm::StringRef api,
256 FuncAttr attr) {
257 func.eraseBody();
258 func.addEntryBlock();
259 func->setAttr(kTFImplements, attr);
260 OpBuilder builder(func.getBody());
261 std::string custom_option_buffer;
262 if (failed(CreateNgramsCustomOption(func, attr.getAttrs(),
263 custom_option_buffer))) {
264 return failure();
265 }
266 auto op = builder.create<CustomOp>(
267 func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
268 api, CustomOption(&builder, custom_option_buffer));
269 builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
270 return success();
271 }
272
VerifySgnnProjection(func::FuncOp func,FuncAttr attr)273 LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) {
274 if (func.getFunctionType().getNumInputs() != 2 ||
275 func.getFunctionType().getNumResults() != 1) {
276 return func.emitError() << "Mismatched number of inputs and outputs.";
277 }
278 auto values_type = GetInputType(func, 0);
279 if (!values_type || !values_type.getElementType().isa<StringType>()) {
280 return func.emitError() << "First input should be a string tensor";
281 }
282 auto row_splits_type = GetInputType(func, 1);
283 if (!row_splits_type ||
284 !row_splits_type.getElementType().isa<IntegerType>()) {
285 return func.emitError() << "Second input should be an integer tensor";
286 }
287
288 auto hash_seed =
289 attr.getAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
290 if (!hash_seed) {
291 return func.emitError()
292 << "'hash_seed' attribute is not set or not an array";
293 }
294 auto output_type = GetResultType(func, 0);
295 if (!output_type || !output_type.getElementType().isa<FloatType>() ||
296 !RankEquals(output_type, 2)) {
297 return func.emitError() << "Output should be a 2D float tensor.";
298 }
299 if (output_type.getDimSize(1) != hash_seed.size()) {
300 return func.emitError()
301 << "Output 2nd dimension should be the num of hash seeds.";
302 }
303
304 auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
305 if (!buckets) {
306 return func.emitError() << "'buckets' attribute is not set or not int";
307 }
308
309 return success();
310 }
311
CreateSgnnProjectionCustomOption(func::FuncOp func,DictionaryAttr attrs,std::string & custom_option_buffer)312 LogicalResult CreateSgnnProjectionCustomOption(
313 func::FuncOp func, DictionaryAttr attrs,
314 std::string& custom_option_buffer) {
315 flexbuffers::Builder fbb;
316 size_t start_map = fbb.StartMap();
317
318 auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
319 auto vector_start = fbb.StartVector("hash_seed");
320 for (int i = 0; i < hash_seed.size(); i++) {
321 fbb.Add(static_cast<int32_t>(
322 (hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
323 }
324 fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);
325
326 auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
327 fbb.Int("buckets", buckets.getInt());
328
329 fbb.EndMap(start_map);
330 fbb.Finish();
331 custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
332 return success();
333 }
334
ConvertSgnnProjection(func::FuncOp func,llvm::StringRef api,FuncAttr attr)335 LogicalResult ConvertSgnnProjection(func::FuncOp func, llvm::StringRef api,
336 FuncAttr attr) {
337 // See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
338 func.eraseBody();
339 func.addEntryBlock();
340 func->setAttr(kTFImplements, attr);
341 OpBuilder builder(func.getBody());
342 std::string custom_option_buffer;
343 if (failed(CreateSgnnProjectionCustomOption(func, attr.getAttrs(),
344 custom_option_buffer))) {
345 return failure();
346 }
347 auto op = builder.create<CustomOp>(
348 func.getLoc(), func.getFunctionType().getResults(), func.getArguments(),
349 api, CustomOption(&builder, custom_option_buffer));
350 builder.create<func::ReturnOp>(func.getLoc(), op.getResults());
351 return success();
352 }
353 } // namespace
354
ConvertTFTextAPI(func::FuncOp func,llvm::StringRef api,FuncAttr attr)355 LogicalResult ConvertTFTextAPI(func::FuncOp func, llvm::StringRef api,
356 FuncAttr attr) {
357 if (api.str() == kWhitespaceTokenizer) {
358 if (succeeded(VerifyWhitespaceTokenizer(func))) {
359 return ConvertWhitespaceTokenizer(func, api, attr);
360 }
361 } else if (api.str() == kNgrams) {
362 if (succeeded(VerifyNgrams(func))) {
363 return ConvertNgrams(func, api, attr);
364 }
365 } else if (api.str() == kCustomSgnnProjection) {
366 if (succeeded(VerifySgnnProjection(func, attr))) {
367 return ConvertSgnnProjection(func, api, attr);
368 }
369 }
370 return failure();
371 }
372
IsTFTextRegistered(const tensorflow::OpRegistry * op_registery)373 bool IsTFTextRegistered(const tensorflow::OpRegistry* op_registery) {
374 const std::vector<std::string> kTFTextOps = {
375 "WhitespaceTokenizeWithOffsets",
376 };
377 for (const auto& iter : kTFTextOps) {
378 if (op_registery->LookUp(iter)) {
379 return true;
380 }
381 }
382 return false;
383 }
384
385 } // namespace TFL
386 } // namespace mlir
387