// xref: /aosp_15_r20/external/libtextclassifier/native/annotator/translate/translate_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "annotator/translate/translate.h"

#include <map>
#include <memory>
#include <mutex>

#include "annotator/model_generated.h"
#include "utils/test-data-test-utils.h"
#include "lang_id/fb_model/lang-id-from-fb.h"
#include "lang_id/lang-id.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
27 
28 namespace libtextclassifier3 {
29 namespace {
30 
31 using testing::AllOf;
32 using testing::Field;
33 
CreateOptionsData(ModeFlag enabled_modes)34 const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
35   TranslateAnnotatorOptionsT options;
36   options.enabled = true;
37   options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
38   options.backoff_options.reset(
39       new TranslateAnnotatorOptions_::BackoffOptionsT());
40   options.enabled_modes = enabled_modes;
41 
42   flatbuffers::FlatBufferBuilder builder;
43   builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
44   return new flatbuffers::DetachedBuffer(builder.Release());
45 }
46 
TestingTranslateAnnotatorOptions(ModeFlag enabled_modes)47 const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions(
48     ModeFlag enabled_modes) {
49   static const flatbuffers::DetachedBuffer* options_data_classification =
50       CreateOptionsData(ModeFlag_CLASSIFICATION);
51   static const flatbuffers::DetachedBuffer* options_data_none =
52       CreateOptionsData(ModeFlag_NONE);
53 
54   if (enabled_modes == ModeFlag_CLASSIFICATION) {
55     return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
56         options_data_classification->data());
57   } else {
58     return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
59         options_data_none->data());
60   }
61 }
62 
63 class TestingTranslateAnnotator : public TranslateAnnotator {
64  public:
65   // Make these protected members public for tests.
66   using TranslateAnnotator::BackoffDetectLanguages;
67   using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
68   using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
69   using TranslateAnnotator::TranslateAnnotator;
70 };
71 
GetModelPath()72 std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
73 
74 class TranslateAnnotatorTest : public ::testing::Test {
75  protected:
TranslateAnnotatorTest(ModeFlag enabled_modes=ModeFlag_CLASSIFICATION)76   explicit TranslateAnnotatorTest(
77       ModeFlag enabled_modes = ModeFlag_CLASSIFICATION)
78       : INIT_UNILIB_FOR_TESTING(unilib_),
79         langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
80             GetModelPath() + "lang_id.smfb")),
81         translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes),
82                              langid_model_.get(), &unilib_) {}
83 
84   UniLib unilib_;
85   std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
86   TestingTranslateAnnotator translate_annotator_;
87 };
88 
89 class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest {
90  protected:
TranslateAnnotatorForNoneTest()91   TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {}
92 };
93 
// Czech text selected by an English speaker yields a "translate"
// classification carrying a single "cs" language prediction.
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
  ClassificationResult classification;
  // ASSERT (not EXPECT): on failure, serialized_entity_data is not a valid
  // flatbuffer and must not be parsed below.
  ASSERT_TRUE(translate_annotator_.ClassifyText(
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
      "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));
  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_NE(entity_data, nullptr);
  ASSERT_NE(entity_data->translate(), nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_NE(predictions, nullptr);
  // ASSERT the size so that Get(0) below is guaranteed to be in bounds.
  ASSERT_EQ(predictions->size(), 1u);
  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
}
111 
// Classifying "学校" for an English speaker populates translate entity data
// with two predictions, "zh" then "ja", in non-increasing confidence order.
TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
  ClassificationResult classification;
  // ASSERT (not EXPECT): on failure, serialized_entity_data is not a valid
  // flatbuffer and must not be parsed below.
  ASSERT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
                                                {0, 2}, "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));
  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_NE(entity_data, nullptr);
  ASSERT_NE(entity_data->translate(), nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_NE(predictions, nullptr);
  // ASSERT the size so that Get(0)/Get(1) below are guaranteed in bounds.
  ASSERT_EQ(predictions->size(), 2u);
  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
  EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
  // EXPECT_GE reports both scores on failure, unlike EXPECT_TRUE(a >= b).
  EXPECT_GE(predictions->Get(0)->confidence_score(),
            predictions->Get(1)->confidence_score());
}
131 
// With no modes enabled in the options, ClassifyText must decline, even for
// input that the CLASSIFICATION-enabled fixture does classify.
TEST_F(TranslateAnnotatorForNoneTest,
       ClassifyTextDisabledClassificationReturnsFalse) {
  const UnicodeText school = UTF8ToUnicodeText("学校");
  ClassificationResult result;
  const bool classified =
      translate_annotator_.ClassifyText(school, {0, 2}, "en", &result);
  EXPECT_FALSE(classified);
}
138 
// English text selected by an English speaker needs no translation.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english = UTF8ToUnicodeText("This is utterly unutterable.");
  ClassificationResult result;
  const bool classified =
      translate_annotator_.ClassifyText(english, {8, 15}, "en", &result);
  EXPECT_FALSE(classified);
}
146 
// A speaker of several languages, none of them Czech, is offered a translation
// of Czech text.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
  const UnicodeText czech =
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček.");
  ClassificationResult result;
  const bool classified =
      translate_annotator_.ClassifyText(czech, {8, 15}, "de,en,ja", &result);
  EXPECT_TRUE(classified);

  EXPECT_THAT(result,
              AllOf(Field(&ClassificationResult::collection, "translate")));
}
157 
// English text needs no translation when English is among the languages the
// user speaks.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
  const UnicodeText english = UTF8ToUnicodeText("This is utterly unutterable.");
  ClassificationResult result;
  const bool classified = translate_annotator_.ClassifyText(
      english, {8, 15}, "cs,en,de,ja", &result);
  EXPECT_FALSE(classified);
}
165 
// Exercises the whitespace/punctuation search in both directions, including
// the cases where it runs off either end of the text.
TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
  auto find_from = [&](int start, int direction) {
    return translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(
        text, start, direction);
  };

  // Backward from the first codepoint: stops at the beginning.
  EXPECT_EQ(find_from(0, -1), text.begin());
  // Forward from inside the last word: reaches the end.
  EXPECT_EQ(find_from(35, 1), text.end());
  // From inside "třicet": the space at 6 going back, the comma at 13 forward.
  EXPECT_EQ(find_from(10, -1), std::next(text.begin(), 6));
  EXPECT_EQ(find_from(10, 1), std::next(text.begin(), 13));
}
183 
// Checks how the substring around a selection grows to whole tokens as the
// requested minimum length increases.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
  // Every case below selects codepoints [35, 37) ("ač") and varies only the
  // minimum length.
  auto around_selection = [&](int minimum_length) {
    return translate_annotator_.TokenAlignedSubstringAroundSpan(
        text, {35, 37}, minimum_length);
  };

  EXPECT_EQ(around_selection(/*minimum_length=*/100), text);
  EXPECT_EQ(around_selection(/*minimum_length=*/0), UTF8ToUnicodeText("ač"));
  EXPECT_EQ(around_selection(/*minimum_length=*/3),
            UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(around_selection(/*minimum_length=*/10),
            UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(around_selection(/*minimum_length=*/11),
            UTF8ToUnicodeText("stříbrných stříkaček"));

  // Without any token boundary, the whole text is returned.
  const UnicodeText unbroken = UTF8ToUnicodeText("reallyreallylongstring");
  EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
                unbroken, {10, 11}, /*minimum_length=*/2),
            unbroken);
}
210 
// An all-whitespace text has no tokens to align to, so the selection must be
// returned unchanged.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
  const UnicodeText blanks = UTF8ToUnicodeText("            ");

  const UnicodeText two_spaces = translate_annotator_.TokenAlignedSubstringAroundSpan(
      blanks, {5, 7}, /*minimum_length=*/3);
  EXPECT_EQ(two_spaces, UTF8ToUnicodeText("  "));

  const UnicodeText empty_selection =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          blanks, {5, 5}, /*minimum_length=*/1);
  EXPECT_EQ(empty_selection, UTF8ToUnicodeText(""));
}
222 
// A selection that starts on whitespace still expands to the whole text when
// the minimum length demands it.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
  const UnicodeText sparse = UTF8ToUnicodeText("a        a");
  const UnicodeText expanded =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          sparse, {5, 7}, /*minimum_length=*/11);
  EXPECT_EQ(expanded, UTF8ToUnicodeText("a        a"));
}
232 
233 }  // namespace
234 }  // namespace libtextclassifier3
235