1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/translate/translate.h"
18
19 #include <memory>
20
21 #include "annotator/model_generated.h"
22 #include "utils/test-data-test-utils.h"
23 #include "lang_id/fb_model/lang-id-from-fb.h"
24 #include "lang_id/lang-id.h"
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27
28 namespace libtextclassifier3 {
29 namespace {
30
31 using testing::AllOf;
32 using testing::Field;
33
CreateOptionsData(ModeFlag enabled_modes)34 const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
35 TranslateAnnotatorOptionsT options;
36 options.enabled = true;
37 options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
38 options.backoff_options.reset(
39 new TranslateAnnotatorOptions_::BackoffOptionsT());
40 options.enabled_modes = enabled_modes;
41
42 flatbuffers::FlatBufferBuilder builder;
43 builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
44 return new flatbuffers::DetachedBuffer(builder.Release());
45 }
46
TestingTranslateAnnotatorOptions(ModeFlag enabled_modes)47 const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions(
48 ModeFlag enabled_modes) {
49 static const flatbuffers::DetachedBuffer* options_data_classification =
50 CreateOptionsData(ModeFlag_CLASSIFICATION);
51 static const flatbuffers::DetachedBuffer* options_data_none =
52 CreateOptionsData(ModeFlag_NONE);
53
54 if (enabled_modes == ModeFlag_CLASSIFICATION) {
55 return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
56 options_data_classification->data());
57 } else {
58 return flatbuffers::GetRoot<TranslateAnnotatorOptions>(
59 options_data_none->data());
60 }
61 }
62
63 class TestingTranslateAnnotator : public TranslateAnnotator {
64 public:
65 // Make these protected members public for tests.
66 using TranslateAnnotator::BackoffDetectLanguages;
67 using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
68 using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
69 using TranslateAnnotator::TranslateAnnotator;
70 };
71
GetModelPath()72 std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
73
74 class TranslateAnnotatorTest : public ::testing::Test {
75 protected:
TranslateAnnotatorTest(ModeFlag enabled_modes=ModeFlag_CLASSIFICATION)76 explicit TranslateAnnotatorTest(
77 ModeFlag enabled_modes = ModeFlag_CLASSIFICATION)
78 : INIT_UNILIB_FOR_TESTING(unilib_),
79 langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
80 GetModelPath() + "lang_id.smfb")),
81 translate_annotator_(TestingTranslateAnnotatorOptions(enabled_modes),
82 langid_model_.get(), &unilib_) {}
83
84 UniLib unilib_;
85 std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
86 TestingTranslateAnnotator translate_annotator_;
87 };
88
89 class TranslateAnnotatorForNoneTest : public TranslateAnnotatorTest {
90 protected:
TranslateAnnotatorForNoneTest()91 TranslateAnnotatorForNoneTest() : TranslateAnnotatorTest(ModeFlag_NONE) {}
92 };
93
// A Czech sentence classified for a user whose language tags are "en" must
// yield a "translate" result carrying exactly one language prediction, "cs",
// with a confidence in (0, 1].
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
  ClassificationResult classification;
  // Fatal on failure: everything below parses and dereferences data taken
  // from `classification`, which is not expected to be usable when
  // ClassifyText() reports failure.
  ASSERT_TRUE(translate_annotator_.ClassifyText(
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
      "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));

  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);

  // Fatal on failure so that Get(0) never indexes past the end of the vector.
  ASSERT_EQ(predictions->size(), 1);
  ASSERT_TRUE(predictions->Get(0)->language_tag() != nullptr);
  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
}
111
// Classifying "学校" for a user whose language tags are "en" must produce
// entity data with two language predictions: "zh" first, then "ja", with the
// first confidence in (0, 1] and not lower than the second.
TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
  ClassificationResult classification;
  // Fatal on failure: everything below parses and dereferences data taken
  // from `classification`, which is not expected to be usable when
  // ClassifyText() reports failure.
  ASSERT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
                                                {0, 2}, "en", &classification));

  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));

  const EntityData* entity_data =
      GetEntityData(classification.serialized_entity_data.data());
  ASSERT_TRUE(entity_data != nullptr);
  ASSERT_TRUE(entity_data->translate() != nullptr);
  const auto predictions =
      entity_data->translate()->language_prediction_results();
  ASSERT_TRUE(predictions != nullptr);

  // Fatal on failure so that Get(0)/Get(1) never index past the end of the
  // vector.
  ASSERT_EQ(predictions->size(), 2);
  ASSERT_TRUE(predictions->Get(0)->language_tag() != nullptr);
  ASSERT_TRUE(predictions->Get(1)->language_tag() != nullptr);

  EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
  EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
  EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
  EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
  // EXPECT_GE prints both scores on failure, unlike EXPECT_TRUE(a >= b).
  EXPECT_GE(predictions->Get(0)->confidence_score(),
            predictions->Get(1)->confidence_score());
}
131
// With the annotator built from the ModeFlag_NONE options, ClassifyText() must
// return false for text that the classification-enabled fixture accepts.
TEST_F(TranslateAnnotatorForNoneTest,
       ClassifyTextDisabledClassificationReturnsFalse) {
  const auto text = UTF8ToUnicodeText("学校");
  ClassificationResult classification;

  const bool classified =
      translate_annotator_.ClassifyText(text, {0, 2}, "en", &classification);

  EXPECT_FALSE(classified);
}
138
// English text must not be classified for translation when the user's
// language tags are "en".
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
  const auto text = UTF8ToUnicodeText("This is utterly unutterable.");
  ClassificationResult classification;

  const bool classified =
      translate_annotator_.ClassifyText(text, {8, 15}, "en", &classification);

  EXPECT_FALSE(classified);
}
146
// Czech text must be classified as "translate" when the user's language tags
// ("de,en,ja") do not include Czech.
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
  const auto text =
      UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček.");
  ClassificationResult classification;

  const bool classified = translate_annotator_.ClassifyText(
      text, {8, 15}, "de,en,ja", &classification);

  EXPECT_TRUE(classified);
  EXPECT_THAT(classification,
              AllOf(Field(&ClassificationResult::collection, "translate")));
}
157
// English text must not be classified for translation when English is one of
// several user language tags ("cs,en,de,ja").
TEST_F(TranslateAnnotatorTest,
       WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
  const auto text = UTF8ToUnicodeText("This is utterly unutterable.");
  ClassificationResult classification;

  const bool classified = translate_annotator_.ClassifyText(
      text, {8, 15}, "cs,en,de,ja", &classification);

  EXPECT_FALSE(classified);
}
165
// Checks the iterator returned by FindIndexOfNextWhitespaceOrPunctuation() for
// backward (-1) and forward (+1) searches from several start positions.
TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");

  // Shorthand for searching `text` from a codepoint index in a direction.
  const auto find_from = [this, &text](int start_codepoint, int direction) {
    return translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(
        text, start_codepoint, direction);
  };

  // Searching backward from the first codepoint yields the beginning.
  EXPECT_EQ(find_from(0, -1), text.begin());
  // Searching forward from inside the last word yields the end.
  EXPECT_EQ(find_from(35, 1), text.end());
  // From inside "třicet": backward yields position 6 (the space after
  // "Třista"), forward yields position 13 (the comma).
  EXPECT_EQ(find_from(10, -1), std::next(text.begin(), 6));
  EXPECT_EQ(find_from(10, 1), std::next(text.begin(), 13));
}
183
// Checks how TokenAlignedSubstringAroundSpan() grows the span {35, 37} ("ač"
// inside "stříkaček") as the requested minimum length increases, and how it
// behaves on text without any whitespace.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
  const UnicodeText text =
      UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");

  // Shorthand: substring around {35, 37} for a given minimum length.
  const auto around_span = [this, &text](int minimum_length) {
    return translate_annotator_.TokenAlignedSubstringAroundSpan(
        text, {35, 37}, minimum_length);
  };

  // A minimum length of 100 yields the whole text.
  EXPECT_EQ(around_span(100), text);
  // A minimum length of 0 yields exactly the original span.
  EXPECT_EQ(around_span(0), UTF8ToUnicodeText("ač"));
  // Minimum lengths of 3 and 10 both yield the enclosing word.
  EXPECT_EQ(around_span(3), UTF8ToUnicodeText("stříkaček"));
  EXPECT_EQ(around_span(10), UTF8ToUnicodeText("stříkaček"));
  // A minimum length of 11 also pulls in the previous word.
  EXPECT_EQ(around_span(11), UTF8ToUnicodeText("stříbrných stříkaček"));

  // Text with no whitespace: the whole string is returned.
  const UnicodeText text_no_whitespace =
      UTF8ToUnicodeText("reallyreallylongstring");
  const UnicodeText no_whitespace_result =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          text_no_whitespace, {10, 11}, /*minimum_length=*/2);
  EXPECT_EQ(no_whitespace_result, UTF8ToUnicodeText("reallyreallylongstring"));
}
210
// Checks TokenAlignedSubstringAroundSpan() on text that consists only of
// whitespace.
//
// NOTE(review): the spans below use indices 5..7, which lie beyond the single
// character in these literals as written here. The whitespace runs may have
// been longer originally -- confirm the literals against version control.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
  const UnicodeText whitespace_text = UTF8ToUnicodeText(" ");

  // Shouldn't modify the selection in case it's all whitespace.
  const UnicodeText non_empty_selection =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          whitespace_text, {5, 7}, /*minimum_length=*/3);
  EXPECT_EQ(non_empty_selection, UTF8ToUnicodeText(" "));

  const UnicodeText empty_selection =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          whitespace_text, {5, 5}, /*minimum_length=*/1);
  EXPECT_EQ(empty_selection, UTF8ToUnicodeText(""));
}
222
// Checks TokenAlignedSubstringAroundSpan() on text that is mostly whitespace
// with a non-whitespace character at each end.
//
// NOTE(review): the span {5, 7} lies beyond the three characters in this
// literal as written here. The whitespace run between the two letters may have
// been longer originally -- confirm the literal against version control.
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
  const UnicodeText mostly_whitespace_text = UTF8ToUnicodeText("a a");

  // Should still select the whole text even if pointing to whitespace
  // initially.
  const UnicodeText selection =
      translate_annotator_.TokenAlignedSubstringAroundSpan(
          mostly_whitespace_text, {5, 7}, /*minimum_length=*/11);
  EXPECT_EQ(selection, UTF8ToUnicodeText("a a"));
}
232
233 } // namespace
234 } // namespace libtextclassifier3
235