xref: /aosp_15_r20/external/libtextclassifier/native/lang_id/fb_model/model-provider-from-fb.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lang_id/fb_model/model-provider-from-fb.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "lang_id/common/file/file-utils.h"
24 #include "lang_id/common/file/mmap.h"
25 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
26 #include "lang_id/common/flatbuffers/model-utils.h"
27 #include "lang_id/common/lite_strings/str-split.h"
28 
29 namespace libtextclassifier3 {
30 namespace mobile {
31 namespace lang_id {
32 
ModelProviderFromFlatbuffer(const std::string & filename)33 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
34     const std::string &filename)
35 
36     // Using mmap as a fast way to read the model bytes.  As the file is
37     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
38     // stay alive for the entire lifetime of this object.
39     : scoped_mmap_(new ScopedMmap(filename)) {
40   Initialize(scoped_mmap_->handle().to_stringpiece());
41 }
42 
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd)43 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
44     FileDescriptorOrHandle fd)
45 
46     // Using mmap as a fast way to read the model bytes.  As the file is
47     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
48     // stay alive for the entire lifetime of this object.
49     : scoped_mmap_(new ScopedMmap(fd)) {
50   Initialize(scoped_mmap_->handle().to_stringpiece());
51 }
52 
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd,std::size_t offset,std::size_t size)53 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
54     FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
55 
56     // Using mmap as a fast way to read the model bytes.  As the file is
57     // unmapped only when the field scoped_mmap_ is destructed, the model bytes
58     // stay alive for the entire lifetime of this object.
59     : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
60   Initialize(scoped_mmap_->handle().to_stringpiece());
61 }
62 
Initialize(StringPiece model_bytes)63 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
64   // Note: valid_ was initialized to false.  In the code below, we set valid_ to
65   // true only if all initialization steps completed successfully.  Otherwise,
66   // we return early, leaving valid_ to its default value false.
67   model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
68   if (model_ == nullptr) {
69     SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
70     return;
71   }
72 
73   // Initialize context_ parameters.
74   if (!saft_fbs::FillParameters(*model_, &context_)) {
75     // FillParameters already performs error logging.
76     return;
77   }
78 
79   // Init languages_.
80   const std::string known_languages_str =
81       context_.Get("supported_languages", "");
82   for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
83     languages_.emplace_back(sp);
84   }
85   if (languages_.empty()) {
86     SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
87     return;
88   }
89 
90   // Init nn_params_.
91   if (!InitNetworkParams()) {
92     // InitNetworkParams already performs error logging.
93     return;
94   }
95 
96   // Everything looks fine.
97   valid_ = true;
98 }
99 
InitNetworkParams()100 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
101   const std::string kInputName = "language-identifier-network";
102   StringPiece bytes =
103       saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
104   if ((bytes.data() == nullptr) || bytes.empty()) {
105     SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
106     return false;
107   }
108   std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
109       new EmbeddingNetworkParamsFromFlatbuffer(bytes));
110   if (!nn_params_from_fb->is_valid()) {
111     SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
112     return false;
113   }
114   nn_params_ = std::move(nn_params_from_fb);
115   return true;
116 }
117 
118 }  // namespace lang_id
119 }  // namespace mobile
120 }  // namespace libtextclassifier3
121