1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 // Utility functions for working with FlatBuffers.
18
19 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
20 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
21
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <utility>

#include "annotator/model_generated.h"
#include "flatbuffers/flatbuffers.h"
27
28 namespace libtextclassifier3 {
29
30 // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
31 // integrity.
32 template <typename FlatbufferMessage>
LoadAndVerifyFlatbuffer(const void * buffer,int size)33 const FlatbufferMessage* LoadAndVerifyFlatbuffer(const void* buffer, int size) {
34 if (size == 0) {
35 return nullptr;
36 }
37 const FlatbufferMessage* message =
38 flatbuffers::GetRoot<FlatbufferMessage>(buffer);
39 if (message == nullptr) {
40 return nullptr;
41 }
42
43 flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(buffer),
44 size);
45 if (message->Verify(verifier)) {
46 return message;
47 } else {
48 // TODO(217577534): Need to figure out why the verifier is failing.
49 return message;
50 }
51 }
52
// Same as LoadAndVerifyFlatbuffer(const void*, int) but takes a string.
// Returns nullptr if the string is too large for its length to be represented
// as an `int`.
template <typename FlatbufferMessage>
const FlatbufferMessage* LoadAndVerifyFlatbuffer(const std::string& buffer) {
  // The pointer-based overload takes an `int` size; refuse rather than
  // silently truncate the length of strings longer than INT_MAX bytes.
  if (buffer.size() >
      static_cast<std::string::size_type>(std::numeric_limits<int>::max())) {
    return nullptr;
  }
  return LoadAndVerifyFlatbuffer<FlatbufferMessage>(
      buffer.data(), static_cast<int>(buffer.size()));
}
59
60 // Loads and interprets the buffer as 'FlatbufferMessage', verifies its
61 // integrity and returns its mutable version.
62 template <typename FlatbufferMessage>
63 std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const void * buffer,int size)64 LoadAndVerifyMutableFlatbuffer(const void* buffer, int size) {
65 const FlatbufferMessage* message =
66 LoadAndVerifyFlatbuffer<FlatbufferMessage>(buffer, size);
67 if (message == nullptr) {
68 return nullptr;
69 }
70 return std::unique_ptr<typename FlatbufferMessage::NativeTableType>(
71 message->UnPack());
72 }
73
// Same as LoadAndVerifyMutableFlatbuffer(const void*, int) but takes a string.
// Returns nullptr if the string is too large for its length to be represented
// as an `int`.
template <typename FlatbufferMessage>
std::unique_ptr<typename FlatbufferMessage::NativeTableType>
LoadAndVerifyMutableFlatbuffer(const std::string& buffer) {
  // The pointer-based overload takes an `int` size; refuse rather than
  // silently truncate the length of strings longer than INT_MAX bytes.
  if (buffer.size() >
      static_cast<std::string::size_type>(std::numeric_limits<int>::max())) {
    return nullptr;
  }
  return LoadAndVerifyMutableFlatbuffer<FlatbufferMessage>(
      buffer.data(), static_cast<int>(buffer.size()));
}
81
// File identifier handed to FlatBufferBuilder::Finish when packing a
// 'FlatbufferMessage'. The generic version supplies none (a null pointer);
// message types that carry an identifier provide an explicit specialization.
template <typename FlatbufferMessage>
const char* FlatbufferFileIdentifier() {
  return {};
}
86
87 template <>
88 inline const char* FlatbufferFileIdentifier<Model>() {
89 return ModelIdentifier();
90 }
91
92 // Packs the mutable flatbuffer message to string.
93 template <typename FlatbufferMessage>
PackFlatbuffer(const typename FlatbufferMessage::NativeTableType * mutable_message)94 std::string PackFlatbuffer(
95 const typename FlatbufferMessage::NativeTableType* mutable_message) {
96 flatbuffers::FlatBufferBuilder builder;
97 builder.Finish(FlatbufferMessage::Pack(builder, mutable_message),
98 FlatbufferFileIdentifier<FlatbufferMessage>());
99 return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
100 builder.GetSize());
101 }
102
103 // A convenience flatbuffer object with its underlying buffer.
104 template <typename T, typename B = flatbuffers::DetachedBuffer>
105 class OwnedFlatbuffer {
106 public:
OwnedFlatbuffer(B && buffer)107 explicit OwnedFlatbuffer(B&& buffer) : buffer_(std::move(buffer)) {}
108
109 // Cast as flatbuffer type.
get()110 const T* get() const { return flatbuffers::GetRoot<T>(buffer_.data()); }
111
buffer()112 const B& buffer() const { return buffer_; }
113
114 const T* operator->() const {
115 return flatbuffers::GetRoot<T>(buffer_.data());
116 }
117
118 private:
119 B buffer_;
120 };
121
122 } // namespace libtextclassifier3
123
124 #endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_FLATBUFFERS_H_
125