1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "annotator/duration/duration.h"
18
19 #include <cstddef>
20 #include <string>
21 #include <vector>
22
23 #include "annotator/collections.h"
24 #include "annotator/model_generated.h"
25 #include "annotator/types-test-util.h"
26 #include "annotator/types.h"
27 #include "utils/tokenizer-utils.h"
28 #include "utils/utf8/unicodetext.h"
29 #include "utils/utf8/unilib.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32
33 namespace libtextclassifier3 {
34 namespace {
35
36 using testing::AllOf;
37 using testing::ElementsAre;
38 using testing::Field;
39 using testing::IsEmpty;
40
41 namespace {
CreateOptionsData(ModeFlag enabled_modes)42 const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
43 DurationAnnotatorOptionsT options;
44 options.enabled = true;
45 options.enabled_modes = enabled_modes;
46
47 options.week_expressions.push_back("week");
48 options.week_expressions.push_back("weeks");
49
50 options.day_expressions.push_back("day");
51 options.day_expressions.push_back("days");
52
53 options.hour_expressions.push_back("hour");
54 options.hour_expressions.push_back("hours");
55
56 options.minute_expressions.push_back("minute");
57 options.minute_expressions.push_back("minutes");
58
59 options.second_expressions.push_back("second");
60 options.second_expressions.push_back("seconds");
61
62 options.filler_expressions.push_back("and");
63 options.filler_expressions.push_back("a");
64 options.filler_expressions.push_back("an");
65 options.filler_expressions.push_back("one");
66
67 options.half_expressions.push_back("half");
68
69 options.sub_token_separator_codepoints.push_back('-');
70
71 flatbuffers::FlatBufferBuilder builder;
72 builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
73 return new flatbuffers::DetachedBuffer(builder.Release());
74 }
75 } // namespace
76
TestingDurationAnnotatorOptions(ModeFlag enabled_modes)77 const DurationAnnotatorOptions* TestingDurationAnnotatorOptions(
78 ModeFlag enabled_modes) {
79 static const flatbuffers::DetachedBuffer* options_data_all =
80 CreateOptionsData(ModeFlag_ALL);
81 static const flatbuffers::DetachedBuffer* options_data_selection =
82 CreateOptionsData(ModeFlag_SELECTION);
83 static const flatbuffers::DetachedBuffer* options_data_no_selection =
84 CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
85
86 if (enabled_modes == ModeFlag_SELECTION) {
87 return flatbuffers::GetRoot<DurationAnnotatorOptions>(
88 options_data_selection->data());
89 } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
90 return flatbuffers::GetRoot<DurationAnnotatorOptions>(
91 options_data_no_selection->data());
92 } else {
93 return flatbuffers::GetRoot<DurationAnnotatorOptions>(
94 options_data_all->data());
95 }
96 }
97
BuildFeatureProcessor(const UniLib * unilib)98 std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
99 static const flatbuffers::DetachedBuffer* options_data = []() {
100 FeatureProcessorOptionsT options;
101 options.context_size = 1;
102 options.max_selection_span = 1;
103 options.snap_label_span_boundaries_to_containing_tokens = false;
104 options.ignored_span_boundary_codepoints.push_back(',');
105
106 options.tokenization_codepoint_config.emplace_back(
107 new TokenizationCodepointRangeT());
108 auto& config = options.tokenization_codepoint_config.back();
109 config->start = 32;
110 config->end = 33;
111 config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
112
113 flatbuffers::FlatBufferBuilder builder;
114 builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
115 return new flatbuffers::DetachedBuffer(builder.Release());
116 }();
117
118 const FeatureProcessorOptions* feature_processor_options =
119 flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
120
121 return std::unique_ptr<FeatureProcessor>(
122 new FeatureProcessor(feature_processor_options, unilib));
123 }
124
125 class DurationAnnotatorTest : public ::testing::Test {
126 protected:
DurationAnnotatorTest(ModeFlag enabled_modes=ModeFlag_ALL)127 explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
128 : INIT_UNILIB_FOR_TESTING(unilib_),
129 feature_processor_(BuildFeatureProcessor(&unilib_)),
130 duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes),
131 feature_processor_.get(), &unilib_) {}
132
Tokenize(const UnicodeText & text)133 std::vector<Token> Tokenize(const UnicodeText& text) {
134 return feature_processor_->Tokenize(text);
135 }
136
137 UniLib unilib_;
138 std::unique_ptr<FeatureProcessor> feature_processor_;
139 DurationAnnotator duration_annotator_;
140 };
141
142 class DurationAnnotatorForAnnotationAndClassificationTest
143 : public DurationAnnotatorTest {
144 protected:
DurationAnnotatorForAnnotationAndClassificationTest()145 DurationAnnotatorForAnnotationAndClassificationTest()
146 : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
147 };
148
149 class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest {
150 protected:
DurationAnnotatorForSelectionTest()151 DurationAnnotatorForSelectionTest()
152 : DurationAnnotatorTest(ModeFlag_SELECTION) {}
153 };
154
// Classifying the selection [14, 24) ("15 minutes") must succeed and produce
// a "duration" result worth 15 minutes in milliseconds.
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
  const UnicodeText context =
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?");

  ClassificationResult outcome;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      context, {14, 24}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &outcome));

  const auto is_fifteen_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 15 * 60 * 1000));
  EXPECT_THAT(outcome, is_fifteen_minutes);
}
165
// With the selection-only fixture, ClassifyText is expected to return false
// even for a selection that spells out a duration.
TEST_F(DurationAnnotatorForSelectionTest,
       ClassifyTextDisabledClassificationReturnsFalse) {
  const UnicodeText context =
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?");

  ClassificationResult outcome;
  EXPECT_FALSE(duration_annotator_.ClassifyText(
      context, {14, 24}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &outcome));
}
173
// The selection [13, 23) covers "15 minutes" but starts and ends in the middle
// of the whitespace-delimited tokens "in15" and "minutesok?". Classification
// must still succeed with a 15 minute duration.
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
  const UnicodeText context = UTF8ToUnicodeText("Wake me up in15 minutesok?");

  ClassificationResult outcome;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      context, {13, 23}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &outcome));

  const auto is_fifteen_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 15 * 60 * 1000));
  EXPECT_THAT(outcome, is_fifteen_minutes);
}
184
// The selection [5, 6) covers only the single character between "Weird" and
// "space"; ClassifyText is expected to reject it.
TEST_F(DurationAnnotatorTest, DoNotClassifyWhenInputIsInvalid) {
  const UnicodeText context = UTF8ToUnicodeText("Weird space");

  ClassificationResult outcome;
  EXPECT_FALSE(duration_annotator_.ClassifyText(
      context, {5, 6}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &outcome));
}
191
// FindAll in selection mode must report exactly one span, [14, 24)
// ("15 minutes"), classified as a 15 minute duration.
TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &annotations));

  const auto is_fifteen_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 15 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_fifteen_minutes)))));
}
210
// The fixture's annotator uses the annotation-and-classification options, so a
// FindAll request in selection mode must succeed while yielding no spans.
TEST_F(DurationAnnotatorForAnnotationAndClassificationTest,
       FindsAllDisabledModeReturnsNoResults) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &annotations));

  EXPECT_THAT(annotations, IsEmpty());
}
222
// "3 and half minutes" (span [16, 34)) must be annotated as 3.5 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_three_and_a_half_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 3.5 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_three_and_a_half_minutes)))));
}
242
// Two quantity/unit pairs joined by "and" must be merged into one span,
// [14, 35), whose duration is the sum of both parts (3 hours + 5 seconds).
TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
  const UnicodeText input =
      UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &annotations));

  const auto is_three_hours_five_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  3 * 60 * 60 * 1000 + 5 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_three_hours_five_seconds)))));
}
262
// One of each supported unit (week, day, hour, minute, second), each preceded
// by a filler word instead of a number, must be combined into the single span
// [13, 65) whose duration is the sum of one of each unit.
TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
  const UnicodeText input = UTF8ToUnicodeText(
      "See you in a week and a day and an hour and a minute and a second");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_one_of_each_unit =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
                      60 * 60 * 1000 + 60 * 1000 + 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_one_of_each_unit)))));
}
283
// "half an hour" (span [16, 28)) must be annotated as 0.5 hours.
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
  const UnicodeText input = UTF8ToUnicodeText("Set a timer for half an hour");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_half_an_hour =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 0.5 * 60 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_half_an_hour)))));
}
302
// A "half" that follows the unit ("1 hour and a half", span [16, 33)) must
// still be added to the duration, giving 1.5 hours.
TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 1 hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &annotations));

  const auto is_an_hour_and_a_half =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 1.5 * 60 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_an_hour_and_a_half)))));
}
322
// "an hour and a half" must be annotated as 1.5 hours. The expected span,
// [19, 34), starts at "hour" and leaves out the leading filler "an".
TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for an hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_an_hour_and_a_half =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 1.5 * 60 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_an_hour_and_a_half)))));
}
342
// A trailing "a second" with no explicit number must count as one second on
// top of the 10 minutes, inside the single span [16, 39).
TEST_F(DurationAnnotatorTest,
       FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_ten_minutes_one_second =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 1 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_ten_minutes_one_second)))));
}
363
// Filler words surrounding the duration must not be swallowed: the expected
// span [22, 46) covers only "10 minutes and 2 seconds", not the leading
// "a a a" or the trailing "an and an".
TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
  const UnicodeText input = UTF8ToUnicodeText(
      "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_ten_minutes_two_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 2 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_ten_minutes_two_seconds)))));
}
383
// A lone "half" with neither a quantity nor a unit must be handled without
// crashing: FindAll is expected to succeed and to report no spans.
TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
  const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &result));

  // IsEmpty() replaces a comparison of size_t against the int literal 0,
  // prints the unexpected elements on failure, and matches the other
  // "no results" tests in this file.
  EXPECT_THAT(result, IsEmpty());
}
394
// Commas glued to the tokens must not prevent matching. The expected span
// [16, 46) starts at "10" and stops right before the comma that follows
// "seconds".
TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_ten_minutes_two_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 2 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_ten_minutes_two_seconds)))));
}
414
// A quantity and a unit fused into one token by '-' ("5-minute", span
// [5, 13)) must be annotated as 5 minutes.
TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
  const UnicodeText input = UTF8ToUnicodeText("Show 5-minute timer.");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_five_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 5 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_five_minutes)))));
}
433
// 1400 hours is more milliseconds than a 32-bit int can hold; the classified
// duration must match the value computed with 64-bit arithmetic.
TEST_F(DurationAnnotatorTest,
       DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
  const UnicodeText context = UTF8ToUnicodeText("1400 hours");

  ClassificationResult outcome;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      context, {0, 10}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &outcome));

  const auto is_1400_hours =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  1400LL * 60LL * 60LL * 1000LL));
  EXPECT_THAT(outcome, is_1400_hours);
}
446
// Unit matching must not depend on letter case: "15 MiNuTeS" (span [14, 24))
// is still a 15 minute duration.
TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_fifteen_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 15 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_fifteen_minutes)))));
}
465
// The half expression must match regardless of letter case: "3 and HaLf
// minutes" (span [16, 34)) is 3.5 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_three_and_a_half_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 3.5 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_three_and_a_half_minutes)))));
}
485
// Filler words must match regardless of letter case: "3 AnD half minutes"
// (span [16, 34)) is 3.5 minutes.
TEST_F(DurationAnnotatorTest,
       FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_three_and_a_half_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 3.5 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_three_and_a_half_minutes)))));
}
506
// A number left without a unit after "20 minutes" is expected to be read as
// seconds: the whole text, span [0, 13), is 20 minutes and 10 seconds.
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("20 minutes 10");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_twenty_minutes_ten_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  20 * 60 * 1000 + 10 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_twenty_minutes_ten_seconds)))));
}
525
// After "20 seconds" the dangling "10" is expected to be ignored: only the
// span [0, 10) is reported, worth 20 seconds.
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
  const UnicodeText input = UTF8ToUnicodeText("20 seconds 10");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_twenty_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 20 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 10)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_twenty_seconds)))));
}
543
// A decimal quantity must be honoured: "10.2 hours" (span [3, 13)) is
// 10 hours and 12 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("in 10.2 hours");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_ten_hours_twelve_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 60 * 1000 + 12 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(3, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_ten_hours_twelve_minutes)))));
}
562
TestingJapaneseDurationAnnotatorOptions()563 const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
564 static const flatbuffers::DetachedBuffer* options_data = []() {
565 DurationAnnotatorOptionsT options;
566 options.enabled = true;
567
568 options.week_expressions.push_back("週間");
569
570 options.day_expressions.push_back("日間");
571
572 options.hour_expressions.push_back("時間");
573
574 options.minute_expressions.push_back("分");
575 options.minute_expressions.push_back("分間");
576
577 options.second_expressions.push_back("秒");
578 options.second_expressions.push_back("秒間");
579
580 options.half_expressions.push_back("半");
581
582 options.require_quantity = true;
583 options.enable_dangling_quantity_interpretation = true;
584
585 flatbuffers::FlatBufferBuilder builder;
586 builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
587 return new flatbuffers::DetachedBuffer(builder.Release());
588 }();
589
590 return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
591 }
592
593 class JapaneseDurationAnnotatorTest : public ::testing::Test {
594 protected:
JapaneseDurationAnnotatorTest()595 JapaneseDurationAnnotatorTest()
596 : INIT_UNILIB_FOR_TESTING(unilib_),
597 feature_processor_(BuildFeatureProcessor(&unilib_)),
598 duration_annotator_(TestingJapaneseDurationAnnotatorOptions(),
599 feature_processor_.get(), &unilib_) {}
600
Tokenize(const UnicodeText & text)601 std::vector<Token> Tokenize(const UnicodeText& text) {
602 return feature_processor_->Tokenize(text);
603 }
604
605 UniLib unilib_;
606 std::unique_ptr<FeatureProcessor> feature_processor_;
607 DurationAnnotator duration_annotator_;
608 };
609
// "10 分" (span [0, 4)) must be annotated as 10 minutes.
TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
  const UnicodeText input = UTF8ToUnicodeText("10 分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_ten_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 10 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 4)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_ten_minutes)))));
}
628
// "2 分 半" (span [0, 5)) must be annotated as 2.5 minutes.
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 半 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  const auto is_two_and_a_half_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 2.5 * 60 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 5)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_two_and_a_half_minutes)))));
}
647
// The Japanese options set require_quantity, so a unit with no number in front
// of it must not be annotated.
TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &annotations));

  EXPECT_THAT(annotations, IsEmpty());
}
658
// The number without a unit after "2 分" is expected to be read as seconds:
// span [0, 6) is 2 minutes and 10 seconds.
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 10 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> annotations;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &annotations));

  const auto is_two_minutes_ten_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  2 * 60 * 1000 + 10 * 1000));
  EXPECT_THAT(annotations,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 6)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_two_minutes_ten_seconds)))));
}
677
678 } // namespace
679 } // namespace libtextclassifier3
680