1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"

#include <cstddef>
#include <string>
#include <vector>

#include "android-base/file.h"
#include "tensorflow_lite_support/cc/port/gmock.h"
#include "tensorflow_lite_support/cc/port/gtest.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
22
23 namespace tflite {
24 namespace support {
25 namespace text {
26 namespace tokenizer {
27
28 using ::android::base::GetExecutableDirectory;
29 using ::testing::ElementsAre;
30 using ::tflite::task::core::LoadBinaryContent;
31
32 namespace {
// Location of the populated vocab file. The tests append this to the
// directory returned by GetExecutableDirectory() to form the full path.
constexpr char kTestRegexVocabSubPath[] =
    "/tensorflow_lite_support/cc/test/testdata/task/text"
    "/vocab_for_regex_tokenizer.txt";

// Location of the vocab file used by the special-token failure test, resolved
// against the executable directory in the same way.
constexpr char kTestRegexEmptyVocabSubPath[] =
    "/tensorflow_lite_support/cc/test/testdata/task/text"
    "/empty_vocab_for_regex_tokenizer.txt";

// Pattern handed to every RegexTokenizer in this file: one or more characters
// that are neither word characters (\w) nor an apostrophe.
constexpr char kRegex[] = "[^\\w\\']+";
42
// Builds a tokenizer from the vocab file path and checks that a sample
// sentence is tokenized into the expected list of subwords.
TEST(RegexTokenizerTest, TestTokenize) {
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  const auto regex_tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, vocab_path);

  const auto tokenized =
      regex_tokenizer->Tokenize("good morning, i'm your teacher.\n");

  // The expected subwords contain no comma, period, space, or newline.
  EXPECT_THAT(tokenized.subwords,
              ElementsAre("good", "morning", "i'm", "your", "teacher"));
}
52
// Builds a tokenizer from the vocab file's contents held in memory (pointer
// plus size) rather than from a path, then checks the tokenization of a
// sample sentence.
TEST(RegexTokenizerTest, TestTokenizeFromFileBuffer) {
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  // Read the whole vocab file into a string that backs the buffer constructor.
  const std::string vocab_contents = LoadBinaryContent(vocab_path.c_str());
  const auto regex_tokenizer = absl::make_unique<RegexTokenizer>(
      kRegex, vocab_contents.data(), vocab_contents.size());

  const auto tokenized =
      regex_tokenizer->Tokenize("good morning, i'm your teacher.\n");

  EXPECT_THAT(tokenized.subwords,
              ElementsAre("good", "morning", "i'm", "your", "teacher"));
}
63
// Verifies that LookupId() succeeds for each listed subword and yields the id
// this test expects for it.
TEST(RegexTokenizerTest, TestLookupId) {
  std::string test_regex_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, test_regex_vocab_path);
  const std::vector<std::string> subwords = {"good", "morning", "i'm", "your",
                                             "teacher"};
  const std::vector<int> true_ids = {52, 1972, 146, 129, 1750};
  // The two vectors are indexed in parallel below; fail fast if they ever get
  // out of sync instead of reading past the end of `true_ids`.
  ASSERT_EQ(subwords.size(), true_ids.size());
  // std::size_t matches the type returned by size(), which avoids a
  // signed/unsigned comparison in the loop condition.
  for (std::size_t i = 0; i < subwords.size(); ++i) {
    // Start from a known value so `id` is never read uninitialized.
    int id = -1;
    ASSERT_TRUE(tokenizer->LookupId(subwords[i], &id))
        << "subword: " << subwords[i];
    ASSERT_EQ(id, true_ids[i]) << "subword: " << subwords[i];
  }
}
78
// Verifies that LookupWord() succeeds for each listed id and yields the
// subword this test expects for it.
TEST(RegexTokenizerTest, TestLookupWord) {
  std::string test_regex_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, test_regex_vocab_path);
  const std::vector<int> ids = {52, 1972, 146, 129, 1750};
  const std::vector<std::string> subwords = {"good", "morning", "i'm", "your",
                                             "teacher"};
  // The two vectors are indexed in parallel below; fail fast if they ever get
  // out of sync instead of reading past the end of `subwords`.
  ASSERT_EQ(ids.size(), subwords.size());
  // std::size_t matches the type returned by size(), which avoids a
  // signed/unsigned comparison in the loop condition.
  for (std::size_t i = 0; i < ids.size(); ++i) {
    // Fresh view per iteration so a stale result cannot satisfy the check.
    absl::string_view result;
    ASSERT_TRUE(tokenizer->LookupWord(ids[i], &result)) << "id: " << ids[i];
    ASSERT_EQ(result, subwords[i]) << "id: " << ids[i];
  }
}
93
// Checks that each special-token getter succeeds and reports the expected id.
TEST(RegexTokenizerTest, TestGetSpecialTokens) {
  // The vocab contains the following special tokens:
  //   <PAD>     0
  //   <START>   1
  //   <UNKNOWN> 2
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  const auto regex_tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, vocab_path);

  // <START>
  int start_id;
  ASSERT_TRUE(regex_tokenizer->GetStartToken(&start_id));
  ASSERT_EQ(start_id, 1);

  // <PAD>
  int pad_id;
  ASSERT_TRUE(regex_tokenizer->GetPadToken(&pad_id));
  ASSERT_EQ(pad_id, 0);

  // <UNKNOWN>
  int unknown_id;
  ASSERT_TRUE(regex_tokenizer->GetUnknownToken(&unknown_id));
  ASSERT_EQ(unknown_id, 2);
}
116
// Builds a tokenizer from the "empty" vocab file and checks that every
// special-token getter reports failure.
TEST(RegexTokenizerTest, TestGetSpecialTokensFailure) {
  const std::string empty_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexEmptyVocabSubPath);
  const auto regex_tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, empty_vocab_path);

  // Each getter must return false; the output values are never inspected.
  int start_id;
  ASSERT_FALSE(regex_tokenizer->GetStartToken(&start_id));

  int pad_id;
  ASSERT_FALSE(regex_tokenizer->GetPadToken(&pad_id));

  int unknown_id;
  ASSERT_FALSE(regex_tokenizer->GetUnknownToken(&unknown_id));
}
132
133 } // namespace
134
135 } // namespace tokenizer
136 } // namespace text
137 } // namespace support
138 } // namespace tflite
139