1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lang_id/fb_model/model-provider-from-fb.h"
18
19 #include <memory>
20 #include <string>
21 #include <utility>
22
23 #include "lang_id/common/file/file-utils.h"
24 #include "lang_id/common/file/mmap.h"
25 #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h"
26 #include "lang_id/common/flatbuffers/model-utils.h"
27 #include "lang_id/common/lite_strings/str-split.h"
28
29 namespace libtextclassifier3 {
30 namespace mobile {
31 namespace lang_id {
32
ModelProviderFromFlatbuffer(const std::string & filename)33 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
34 const std::string &filename)
35
36 // Using mmap as a fast way to read the model bytes. As the file is
37 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
38 // stay alive for the entire lifetime of this object.
39 : scoped_mmap_(new ScopedMmap(filename)) {
40 Initialize(scoped_mmap_->handle().to_stringpiece());
41 }
42
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd)43 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
44 FileDescriptorOrHandle fd)
45
46 // Using mmap as a fast way to read the model bytes. As the file is
47 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
48 // stay alive for the entire lifetime of this object.
49 : scoped_mmap_(new ScopedMmap(fd)) {
50 Initialize(scoped_mmap_->handle().to_stringpiece());
51 }
52
ModelProviderFromFlatbuffer(FileDescriptorOrHandle fd,std::size_t offset,std::size_t size)53 ModelProviderFromFlatbuffer::ModelProviderFromFlatbuffer(
54 FileDescriptorOrHandle fd, std::size_t offset, std::size_t size)
55
56 // Using mmap as a fast way to read the model bytes. As the file is
57 // unmapped only when the field scoped_mmap_ is destructed, the model bytes
58 // stay alive for the entire lifetime of this object.
59 : scoped_mmap_(new ScopedMmap(fd, offset, size)) {
60 Initialize(scoped_mmap_->handle().to_stringpiece());
61 }
62
Initialize(StringPiece model_bytes)63 void ModelProviderFromFlatbuffer::Initialize(StringPiece model_bytes) {
64 // Note: valid_ was initialized to false. In the code below, we set valid_ to
65 // true only if all initialization steps completed successfully. Otherwise,
66 // we return early, leaving valid_ to its default value false.
67 model_ = saft_fbs::GetVerifiedModelFromBytes(model_bytes);
68 if (model_ == nullptr) {
69 SAFTM_LOG(ERROR) << "Unable to initialize ModelProviderFromFlatbuffer";
70 return;
71 }
72
73 // Initialize context_ parameters.
74 if (!saft_fbs::FillParameters(*model_, &context_)) {
75 // FillParameters already performs error logging.
76 return;
77 }
78
79 // Init languages_.
80 const std::string known_languages_str =
81 context_.Get("supported_languages", "");
82 for (StringPiece sp : LiteStrSplit(known_languages_str, ',')) {
83 languages_.emplace_back(sp);
84 }
85 if (languages_.empty()) {
86 SAFTM_LOG(ERROR) << "Unable to find list of supported_languages";
87 return;
88 }
89
90 // Init nn_params_.
91 if (!InitNetworkParams()) {
92 // InitNetworkParams already performs error logging.
93 return;
94 }
95
96 // Everything looks fine.
97 valid_ = true;
98 }
99
InitNetworkParams()100 bool ModelProviderFromFlatbuffer::InitNetworkParams() {
101 const std::string kInputName = "language-identifier-network";
102 StringPiece bytes =
103 saft_fbs::GetInputBytes(saft_fbs::GetInputByName(model_, kInputName));
104 if ((bytes.data() == nullptr) || bytes.empty()) {
105 SAFTM_LOG(ERROR) << "Unable to get bytes for model input " << kInputName;
106 return false;
107 }
108 std::unique_ptr<EmbeddingNetworkParamsFromFlatbuffer> nn_params_from_fb(
109 new EmbeddingNetworkParamsFromFlatbuffer(bytes));
110 if (!nn_params_from_fb->is_valid()) {
111 SAFTM_LOG(ERROR) << "EmbeddingNetworkParamsFromFlatbuffer not valid";
112 return false;
113 }
114 nn_params_ = std::move(nn_params_from_fb);
115 return true;
116 }
117
118 } // namespace lang_id
119 } // namespace mobile
}  // namespace libtextclassifier3
121