xref: /aosp_15_r20/external/libtextclassifier/native/annotator/duration/duration_test.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "annotator/duration/duration.h"
18 
19 #include <cstddef>
20 #include <string>
21 #include <vector>
22 
23 #include "annotator/collections.h"
24 #include "annotator/model_generated.h"
25 #include "annotator/types-test-util.h"
26 #include "annotator/types.h"
27 #include "utils/tokenizer-utils.h"
28 #include "utils/utf8/unicodetext.h"
29 #include "utils/utf8/unilib.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32 
33 namespace libtextclassifier3 {
34 namespace {
35 
36 using testing::AllOf;
37 using testing::ElementsAre;
38 using testing::Field;
39 using testing::IsEmpty;
40 
41 namespace {
CreateOptionsData(ModeFlag enabled_modes)42 const flatbuffers::DetachedBuffer* CreateOptionsData(ModeFlag enabled_modes) {
43   DurationAnnotatorOptionsT options;
44   options.enabled = true;
45   options.enabled_modes = enabled_modes;
46 
47   options.week_expressions.push_back("week");
48   options.week_expressions.push_back("weeks");
49 
50   options.day_expressions.push_back("day");
51   options.day_expressions.push_back("days");
52 
53   options.hour_expressions.push_back("hour");
54   options.hour_expressions.push_back("hours");
55 
56   options.minute_expressions.push_back("minute");
57   options.minute_expressions.push_back("minutes");
58 
59   options.second_expressions.push_back("second");
60   options.second_expressions.push_back("seconds");
61 
62   options.filler_expressions.push_back("and");
63   options.filler_expressions.push_back("a");
64   options.filler_expressions.push_back("an");
65   options.filler_expressions.push_back("one");
66 
67   options.half_expressions.push_back("half");
68 
69   options.sub_token_separator_codepoints.push_back('-');
70 
71   flatbuffers::FlatBufferBuilder builder;
72   builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
73   return new flatbuffers::DetachedBuffer(builder.Release());
74 }
75 }  // namespace
76 
TestingDurationAnnotatorOptions(ModeFlag enabled_modes)77 const DurationAnnotatorOptions* TestingDurationAnnotatorOptions(
78     ModeFlag enabled_modes) {
79   static const flatbuffers::DetachedBuffer* options_data_all =
80       CreateOptionsData(ModeFlag_ALL);
81   static const flatbuffers::DetachedBuffer* options_data_selection =
82       CreateOptionsData(ModeFlag_SELECTION);
83   static const flatbuffers::DetachedBuffer* options_data_no_selection =
84       CreateOptionsData(ModeFlag_ANNOTATION_AND_CLASSIFICATION);
85 
86   if (enabled_modes == ModeFlag_SELECTION) {
87     return flatbuffers::GetRoot<DurationAnnotatorOptions>(
88         options_data_selection->data());
89   } else if (enabled_modes == ModeFlag_ANNOTATION_AND_CLASSIFICATION) {
90     return flatbuffers::GetRoot<DurationAnnotatorOptions>(
91         options_data_no_selection->data());
92   } else {
93     return flatbuffers::GetRoot<DurationAnnotatorOptions>(
94         options_data_all->data());
95   }
96 }
97 
BuildFeatureProcessor(const UniLib * unilib)98 std::unique_ptr<FeatureProcessor> BuildFeatureProcessor(const UniLib* unilib) {
99   static const flatbuffers::DetachedBuffer* options_data = []() {
100     FeatureProcessorOptionsT options;
101     options.context_size = 1;
102     options.max_selection_span = 1;
103     options.snap_label_span_boundaries_to_containing_tokens = false;
104     options.ignored_span_boundary_codepoints.push_back(',');
105 
106     options.tokenization_codepoint_config.emplace_back(
107         new TokenizationCodepointRangeT());
108     auto& config = options.tokenization_codepoint_config.back();
109     config->start = 32;
110     config->end = 33;
111     config->role = TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;
112 
113     flatbuffers::FlatBufferBuilder builder;
114     builder.Finish(FeatureProcessorOptions::Pack(builder, &options));
115     return new flatbuffers::DetachedBuffer(builder.Release());
116   }();
117 
118   const FeatureProcessorOptions* feature_processor_options =
119       flatbuffers::GetRoot<FeatureProcessorOptions>(options_data->data());
120 
121   return std::unique_ptr<FeatureProcessor>(
122       new FeatureProcessor(feature_processor_options, unilib));
123 }
124 
125 class DurationAnnotatorTest : public ::testing::Test {
126  protected:
DurationAnnotatorTest(ModeFlag enabled_modes=ModeFlag_ALL)127   explicit DurationAnnotatorTest(ModeFlag enabled_modes = ModeFlag_ALL)
128       : INIT_UNILIB_FOR_TESTING(unilib_),
129         feature_processor_(BuildFeatureProcessor(&unilib_)),
130         duration_annotator_(TestingDurationAnnotatorOptions(enabled_modes),
131                             feature_processor_.get(), &unilib_) {}
132 
Tokenize(const UnicodeText & text)133   std::vector<Token> Tokenize(const UnicodeText& text) {
134     return feature_processor_->Tokenize(text);
135   }
136 
137   UniLib unilib_;
138   std::unique_ptr<FeatureProcessor> feature_processor_;
139   DurationAnnotator duration_annotator_;
140 };
141 
142 class DurationAnnotatorForAnnotationAndClassificationTest
143     : public DurationAnnotatorTest {
144  protected:
DurationAnnotatorForAnnotationAndClassificationTest()145   DurationAnnotatorForAnnotationAndClassificationTest()
146       : DurationAnnotatorTest(ModeFlag_ANNOTATION_AND_CLASSIFICATION) {}
147 };
148 
149 class DurationAnnotatorForSelectionTest : public DurationAnnotatorTest {
150  protected:
DurationAnnotatorForSelectionTest()151   DurationAnnotatorForSelectionTest()
152       : DurationAnnotatorTest(ModeFlag_SELECTION) {}
153 };
154 
// ClassifyText on the selection {14, 24} must succeed and report a "duration"
// of 15 minutes.
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
  const auto expected_ms = 15 * 60 * 1000;

  ClassificationResult classified;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classified));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(classified, is_expected_duration);
}
165 
// With the selection-only fixture, ClassifyText must return false for the
// same input that the all-modes fixture classifies.
TEST_F(DurationAnnotatorForSelectionTest,
       ClassifyTextDisabledClassificationReturnsFalse) {
  ClassificationResult classified;
  const bool was_classified = duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in 15 minutes ok?"), {14, 24},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classified);
  EXPECT_FALSE(was_classified);
}
173 
// The selection {13, 23} sits inside the whitespace-delimited chunks "in15"
// and "minutesok?"; classification must still yield a 15 minute duration.
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
  const auto expected_ms = 15 * 60 * 1000;

  ClassificationResult classified;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Wake me up in15 minutesok?"), {13, 23},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classified));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(classified, is_expected_duration);
}
184 
// A selection of {5, 6} in "Weird space" must not be classified.
TEST_F(DurationAnnotatorTest, DoNotClassifyWhenInputIsInvalid) {
  ClassificationResult classified;
  const bool was_classified = duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("Weird space"), {5, 6},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classified);
  EXPECT_FALSE(was_classified);
}
191 
// FindAll in selection mode must report exactly one span, (14, 24), carrying
// a "duration" classification of 15 minutes.
TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 15 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
210 
// The fixture enables only annotation and classification, so a FindAll call
// made for ModeFlag_SELECTION must succeed but produce no spans.
TEST_F(DurationAnnotatorForAnnotationAndClassificationTest,
       FindsAllDisabledModeReturnsNoResults) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  const bool succeeded = duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &spans);
  EXPECT_TRUE(succeeded);

  EXPECT_THAT(spans, IsEmpty());
}
222 
// "3 and half minutes" must be found as one span, (16, 34), worth 3.5 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 3.5 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
242 
// "3 hours and 5 seconds" must be merged into a single span, (14, 35), whose
// duration is the sum of both parts.
TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
  const UnicodeText input =
      UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 3 * 60 * 60 * 1000 + 5 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
262 
// One sentence mentioning a week, a day, an hour, a minute and a second must
// come back as a single span, (13, 65), whose duration adds up all five units.
TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
  const UnicodeText input = UTF8ToUnicodeText(
      "See you in a week and a day and an hour and a minute and a second");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
                           60 * 60 * 1000 + 60 * 1000 + 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
283 
// "half an hour" must be found as span (16, 28) with a duration of 0.5 hours.
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
  const UnicodeText input = UTF8ToUnicodeText("Set a timer for half an hour");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 0.5 * 60 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
302 
// "1 hour and a half" places the half expression after the unit; the expected
// result is span (16, 33) with a duration of 1.5 hours.
TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 1 hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 1.5 * 60 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
322 
// "an hour and a half" has no numeric quantity; the expected span is (19, 34)
// and the expected duration is 1.5 hours.
TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for an hour and a half");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 1.5 * 60 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
342 
// "10 minutes and a second": the trailing unit has no number and must count
// as one second. Expected span is (16, 39).
TEST_F(DurationAnnotatorTest,
       FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 10 * 60 * 1000 + 1 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
363 
// Filler words before and after the duration must stay outside the span:
// only (22, 46) is expected, with a duration of 10 minutes and 2 seconds.
TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
  const UnicodeText input = UTF8ToUnicodeText(
      "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 10 * 60 * 1000 + 2 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
383 
// A lone half expression with no unit ("half ok?") must not crash FindAll;
// the call has to succeed and report no spans.
TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
  const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &result));

  // Use the container matcher rather than comparing size() (unsigned) with a
  // signed literal: it avoids a sign-compare warning, prints the unexpected
  // elements on failure, and matches the other "no result" tests in this file.
  EXPECT_THAT(result, IsEmpty());
}
394 
// Commas attached to the quantity, unit and filler tokens must not prevent
// detection: expected span is (16, 46) with 10 minutes and 2 seconds.
TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 10 * 60 * 1000 + 2 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
414 
// "5-minute" is a single whitespace-delimited token joining quantity and unit
// with '-'; it must be found as span (5, 13) worth 5 minutes.
TEST_F(DurationAnnotatorTest, FindsCorrectlyWithCombinedQuantityUnitToken) {
  const UnicodeText input = UTF8ToUnicodeText("Show 5-minute timer.");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 5 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(5, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
433 
// 1400 hours is more milliseconds than a 32-bit int can hold; the classified
// duration must equal the value computed with 64-bit arithmetic.
TEST_F(DurationAnnotatorTest,
       DoesNotIntOverflowWithDurationThatHasMoreThanInt32Millis) {
  const auto expected_ms = 1400LL * 60LL * 60LL * 1000LL;

  ClassificationResult classified;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      UTF8ToUnicodeText("1400 hours"), {0, 10},
      AnnotationUsecase_ANNOTATION_USECASE_RAW, &classified));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(classified, is_expected_duration);
}
446 
// The unit is written in mixed case ("MiNuTeS") and must still match:
// expected span is (14, 24) with a duration of 15 minutes.
TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 15 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
465 
// The half expression is written in mixed case ("HaLf") and must still match:
// expected span is (16, 34) with a duration of 3.5 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 3.5 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
485 
// The filler word is written in mixed case ("AnD") and must still match:
// expected span is (16, 34) with a duration of 3.5 minutes.
TEST_F(DurationAnnotatorTest,
       FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 3.5 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
506 
// In "20 minutes 10" the trailing bare number is expected to be read as
// seconds: span (0, 13) with a duration of 20 minutes and 10 seconds.
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("20 minutes 10");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 20 * 60 * 1000 + 10 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
525 
// In "20 seconds 10" the trailing bare number must be left out: the expected
// span is only (0, 10) and the duration only 20 seconds.
TEST_F(DurationAnnotatorTest, FindsDurationWithDanglingQuantityNotSupported) {
  const UnicodeText input = UTF8ToUnicodeText("20 seconds 10");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 20 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 10)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
543 
// A decimal quantity ("10.2 hours") must be found as span (3, 13) and
// converted to 10 hours plus 12 minutes.
TEST_F(DurationAnnotatorTest, FindsDurationWithDecimalQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("in 10.2 hours");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 10 * 60 * 60 * 1000 + 12 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(3, 13)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
562 
TestingJapaneseDurationAnnotatorOptions()563 const DurationAnnotatorOptions* TestingJapaneseDurationAnnotatorOptions() {
564   static const flatbuffers::DetachedBuffer* options_data = []() {
565     DurationAnnotatorOptionsT options;
566     options.enabled = true;
567 
568     options.week_expressions.push_back("週間");
569 
570     options.day_expressions.push_back("日間");
571 
572     options.hour_expressions.push_back("時間");
573 
574     options.minute_expressions.push_back("分");
575     options.minute_expressions.push_back("分間");
576 
577     options.second_expressions.push_back("秒");
578     options.second_expressions.push_back("秒間");
579 
580     options.half_expressions.push_back("半");
581 
582     options.require_quantity = true;
583     options.enable_dangling_quantity_interpretation = true;
584 
585     flatbuffers::FlatBufferBuilder builder;
586     builder.Finish(DurationAnnotatorOptions::Pack(builder, &options));
587     return new flatbuffers::DetachedBuffer(builder.Release());
588   }();
589 
590   return flatbuffers::GetRoot<DurationAnnotatorOptions>(options_data->data());
591 }
592 
593 class JapaneseDurationAnnotatorTest : public ::testing::Test {
594  protected:
JapaneseDurationAnnotatorTest()595   JapaneseDurationAnnotatorTest()
596       : INIT_UNILIB_FOR_TESTING(unilib_),
597         feature_processor_(BuildFeatureProcessor(&unilib_)),
598         duration_annotator_(TestingJapaneseDurationAnnotatorOptions(),
599                             feature_processor_.get(), &unilib_) {}
600 
Tokenize(const UnicodeText & text)601   std::vector<Token> Tokenize(const UnicodeText& text) {
602     return feature_processor_->Tokenize(text);
603   }
604 
605   UniLib unilib_;
606   std::unique_ptr<FeatureProcessor> feature_processor_;
607   DurationAnnotator duration_annotator_;
608 };
609 
// "10 分" must be found as span (0, 4) with a duration of 10 minutes.
TEST_F(JapaneseDurationAnnotatorTest, FindsDuration) {
  const UnicodeText input = UTF8ToUnicodeText("10 分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 10 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 4)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
628 
// "2 分 半" must be found as span (0, 5) with a duration of 2.5 minutes.
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 半 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 2.5 * 60 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 5)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
647 
// The Japanese options set require_quantity, so a unit with no number in
// front of it must yield no spans.
TEST_F(JapaneseDurationAnnotatorTest, IgnoresDurationWithoutQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("分 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  const bool succeeded = duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_ANNOTATION, &spans);
  EXPECT_TRUE(succeeded);

  EXPECT_THAT(spans, IsEmpty());
}
658 
// In "2 分 10" the trailing bare number is expected to be read as seconds:
// span (0, 6) with a duration of 2 minutes and 10 seconds.
TEST_F(JapaneseDurationAnnotatorTest, FindsDurationWithDanglingQuantity) {
  const UnicodeText input = UTF8ToUnicodeText("2 分 10 の アラーム");
  std::vector<Token> input_tokens = Tokenize(input);
  const auto expected_ms = 2 * 60 * 1000 + 10 * 1000;

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, input_tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW,
      ModeFlag_SELECTION, &spans));

  const auto is_expected_duration =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, expected_ms));
  EXPECT_THAT(spans,
              ElementsAre(AllOf(
                  Field(&AnnotatedSpan::span, CodepointSpan(0, 6)),
                  Field(&AnnotatedSpan::classification,
                        ElementsAre(is_expected_duration)))));
}
677 
678 }  // namespace
679 }  // namespace libtextclassifier3
680