xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/test/text/regex_tokenizer_test.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
17 
18 #include "android-base/file.h"
19 #include "tensorflow_lite_support/cc/port/gmock.h"
20 #include "tensorflow_lite_support/cc/port/gtest.h"
21 #include "tensorflow_lite_support/cc/task/core/task_utils.h"
22 
23 namespace tflite {
24 namespace support {
25 namespace text {
26 namespace tokenizer {
27 
28 using ::android::base::GetExecutableDirectory;
29 using ::testing::ElementsAre;
30 using ::tflite::task::core::LoadBinaryContent;
31 
32 namespace {
// Path (relative to the test executable's directory) of a vocab file whose
// entries back the lookup expectations below; it includes the
// <PAD>/<START>/<UNKNOWN> special tokens exercised in TestGetSpecialTokens.
constexpr char kTestRegexVocabSubPath[] =
    "/tensorflow_lite_support/cc/test/testdata/task/text/"
    "vocab_for_regex_tokenizer.txt";

// Path of an empty vocab file, used to exercise the special-token lookup
// failure paths.
constexpr char kTestRegexEmptyVocabSubPath[] =
    "/tensorflow_lite_support/cc/test/testdata/task/text/"
    "empty_vocab_for_regex_tokenizer.txt";

// Delimiter pattern: tokens are split on any run of characters that is
// neither a word character nor an apostrophe.
constexpr char kRegex[] = "[^\\w\\']+";
42 
// Tokenizing a sentence splits it on non-word characters while keeping
// apostrophes inside tokens (e.g. "i'm"), per kRegex.
TEST(RegexTokenizerTest, TestTokenize) {
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer = absl::make_unique<RegexTokenizer>(kRegex, vocab_path);

  const auto results = tokenizer->Tokenize("good    morning, i'm your teacher.\n");

  EXPECT_THAT(results.subwords,
              ElementsAre("good", "morning", "i'm", "your", "teacher"));
}
52 
// Tokenization behaves identically when the vocab is supplied as an
// in-memory buffer instead of a file path.
TEST(RegexTokenizerTest, TestTokenizeFromFileBuffer) {
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  std::string vocab_buffer = LoadBinaryContent(vocab_path.c_str());
  auto tokenizer = absl::make_unique<RegexTokenizer>(kRegex, vocab_buffer.data(),
                                                     vocab_buffer.size());

  const auto results = tokenizer->Tokenize("good    morning, i'm your teacher.\n");

  EXPECT_THAT(results.subwords,
              ElementsAre("good", "morning", "i'm", "your", "teacher"));
}
63 
// LookupId maps each known subword to its expected vocab id.
TEST(RegexTokenizerTest, TestLookupId) {
  std::string test_regex_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, test_regex_vocab_path);
  const std::vector<std::string> subwords = {"good", "morning", "i'm", "your",
                                             "teacher"};
  const std::vector<int> true_ids = {52, 1972, 146, 129, 1750};
  // Guard against the parallel vectors drifting out of sync.
  ASSERT_EQ(subwords.size(), true_ids.size());
  // size_t index avoids the signed/unsigned comparison of the previous
  // `int i < subwords.size()` loop.
  for (size_t i = 0; i < subwords.size(); ++i) {
    int id;
    ASSERT_TRUE(tokenizer->LookupId(subwords[i], &id));
    ASSERT_EQ(id, true_ids[i]);
  }
}
78 
// LookupWord maps each vocab id back to its expected subword (the inverse of
// the LookupId mapping checked above).
TEST(RegexTokenizerTest, TestLookupWord) {
  std::string test_regex_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer =
      absl::make_unique<RegexTokenizer>(kRegex, test_regex_vocab_path);
  const std::vector<int> ids = {52, 1972, 146, 129, 1750};
  const std::vector<std::string> subwords = {"good", "morning", "i'm", "your",
                                             "teacher"};
  // Guard against the parallel vectors drifting out of sync.
  ASSERT_EQ(ids.size(), subwords.size());
  // size_t index avoids the signed/unsigned comparison of the previous
  // `int i < ids.size()` loop.
  for (size_t i = 0; i < ids.size(); ++i) {
    absl::string_view result;
    ASSERT_TRUE(tokenizer->LookupWord(ids[i], &result));
    ASSERT_EQ(result, subwords[i]);
  }
}
93 
// The vocab contains the following special tokens:
//   <PAD> 0
//   <START> 1
//   <UNKNOWN> 2
TEST(RegexTokenizerTest, TestGetSpecialTokens) {
  const std::string vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexVocabSubPath);
  auto tokenizer = absl::make_unique<RegexTokenizer>(kRegex, vocab_path);

  int token_id;
  ASSERT_TRUE(tokenizer->GetPadToken(&token_id));
  ASSERT_EQ(token_id, 0);

  ASSERT_TRUE(tokenizer->GetStartToken(&token_id));
  ASSERT_EQ(token_id, 1);

  ASSERT_TRUE(tokenizer->GetUnknownToken(&token_id));
  ASSERT_EQ(token_id, 2);
}
116 
// With an empty vocab, none of the special tokens can be resolved, so every
// special-token getter reports failure.
TEST(RegexTokenizerTest, TestGetSpecialTokensFailure) {
  const std::string empty_vocab_path =
      absl::StrCat(GetExecutableDirectory(), kTestRegexEmptyVocabSubPath);
  auto tokenizer = absl::make_unique<RegexTokenizer>(kRegex, empty_vocab_path);

  int token_id;
  ASSERT_FALSE(tokenizer->GetStartToken(&token_id));
  ASSERT_FALSE(tokenizer->GetPadToken(&token_id));
  ASSERT_FALSE(tokenizer->GetUnknownToken(&token_id));
}
132 
133 }  // namespace
134 
135 }  // namespace tokenizer
136 }  // namespace text
137 }  // namespace support
138 }  // namespace tflite
139