xref: /aosp_15_r20/frameworks/native/libs/input/TfLiteMotionPredictor.cpp (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1*38e8c45fSAndroid Build Coastguard Worker /*
2*38e8c45fSAndroid Build Coastguard Worker  * Copyright (C) 2023 The Android Open Source Project
3*38e8c45fSAndroid Build Coastguard Worker  *
4*38e8c45fSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*38e8c45fSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*38e8c45fSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*38e8c45fSAndroid Build Coastguard Worker  *
8*38e8c45fSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*38e8c45fSAndroid Build Coastguard Worker  *
10*38e8c45fSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*38e8c45fSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*38e8c45fSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*38e8c45fSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*38e8c45fSAndroid Build Coastguard Worker  * limitations under the License.
15*38e8c45fSAndroid Build Coastguard Worker  */
16*38e8c45fSAndroid Build Coastguard Worker 
17*38e8c45fSAndroid Build Coastguard Worker #define LOG_TAG "TfLiteMotionPredictor"
18*38e8c45fSAndroid Build Coastguard Worker #include <input/TfLiteMotionPredictor.h>
19*38e8c45fSAndroid Build Coastguard Worker 
20*38e8c45fSAndroid Build Coastguard Worker #include <fcntl.h>
21*38e8c45fSAndroid Build Coastguard Worker #include <sys/mman.h>
22*38e8c45fSAndroid Build Coastguard Worker #include <unistd.h>
23*38e8c45fSAndroid Build Coastguard Worker 
24*38e8c45fSAndroid Build Coastguard Worker #include <algorithm>
25*38e8c45fSAndroid Build Coastguard Worker #include <cmath>
26*38e8c45fSAndroid Build Coastguard Worker #include <cstddef>
27*38e8c45fSAndroid Build Coastguard Worker #include <cstdint>
28*38e8c45fSAndroid Build Coastguard Worker #include <memory>
29*38e8c45fSAndroid Build Coastguard Worker #include <span>
30*38e8c45fSAndroid Build Coastguard Worker #include <type_traits>
31*38e8c45fSAndroid Build Coastguard Worker #include <utility>
32*38e8c45fSAndroid Build Coastguard Worker 
33*38e8c45fSAndroid Build Coastguard Worker #include <android-base/file.h>
34*38e8c45fSAndroid Build Coastguard Worker #include <android-base/logging.h>
35*38e8c45fSAndroid Build Coastguard Worker #include <android-base/mapped_file.h>
36*38e8c45fSAndroid Build Coastguard Worker #define ATRACE_TAG ATRACE_TAG_INPUT
37*38e8c45fSAndroid Build Coastguard Worker #include <cutils/trace.h>
38*38e8c45fSAndroid Build Coastguard Worker #include <log/log.h>
39*38e8c45fSAndroid Build Coastguard Worker #include <utils/Timers.h>
40*38e8c45fSAndroid Build Coastguard Worker 
41*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/core/api/error_reporter.h"
42*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/core/api/op_resolver.h"
43*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
44*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/builtin_op_kernels.h"
45*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
46*38e8c45fSAndroid Build Coastguard Worker #include "tensorflow/lite/mutable_op_resolver.h"
47*38e8c45fSAndroid Build Coastguard Worker 
48*38e8c45fSAndroid Build Coastguard Worker #include "tinyxml2.h"
49*38e8c45fSAndroid Build Coastguard Worker 
50*38e8c45fSAndroid Build Coastguard Worker namespace android {
51*38e8c45fSAndroid Build Coastguard Worker namespace {
52*38e8c45fSAndroid Build Coastguard Worker 
// Key of the model signature whose runner is used for inference.
constexpr char SIGNATURE_KEY[]{"serving_default"};

// Names of the input tensors of the "serving_default" signature.
constexpr char INPUT_R[]{"r"};
constexpr char INPUT_PHI[]{"phi"};
constexpr char INPUT_PRESSURE[]{"pressure"};
constexpr char INPUT_TILT[]{"tilt"};
constexpr char INPUT_ORIENTATION[]{"orientation"};

// Names of the output tensors of the "serving_default" signature.
constexpr char OUTPUT_R[]{"r"};
constexpr char OUTPUT_PHI[]{"phi"};
constexpr char OUTPUT_PRESSURE[]{"pressure"};
66*38e8c45fSAndroid Build Coastguard Worker 
// std::filesystem::exists would be the natural choice here, but it requires libc++fs, which
// causes build issues in other parts of the system.
#if defined(__ANDROID__)
// Returns true when stat(2) succeeds for |filename|, i.e. an entry exists at that path and the
// calling process is able to stat it.
bool fileExists(const char* filename) {
    struct stat statBuffer {};
    if (::stat(filename, &statBuffer) != 0) {
        return false;
    }
    return true;
}
#endif
75*38e8c45fSAndroid Build Coastguard Worker 
getModelPath()76*38e8c45fSAndroid Build Coastguard Worker std::string getModelPath() {
77*38e8c45fSAndroid Build Coastguard Worker #if defined(__ANDROID__)
78*38e8c45fSAndroid Build Coastguard Worker     static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite";
79*38e8c45fSAndroid Build Coastguard Worker     if (fileExists(oemModel)) {
80*38e8c45fSAndroid Build Coastguard Worker         return oemModel;
81*38e8c45fSAndroid Build Coastguard Worker     }
82*38e8c45fSAndroid Build Coastguard Worker     return "/system/etc/motion_predictor_model.tflite";
83*38e8c45fSAndroid Build Coastguard Worker #else
84*38e8c45fSAndroid Build Coastguard Worker     return base::GetExecutableDirectory() + "/motion_predictor_model.tflite";
85*38e8c45fSAndroid Build Coastguard Worker #endif
86*38e8c45fSAndroid Build Coastguard Worker }
87*38e8c45fSAndroid Build Coastguard Worker 
getConfigPath()88*38e8c45fSAndroid Build Coastguard Worker std::string getConfigPath() {
89*38e8c45fSAndroid Build Coastguard Worker     // The config file should be alongside the model file.
90*38e8c45fSAndroid Build Coastguard Worker     return base::Dirname(getModelPath()) + "/motion_predictor_config.xml";
91*38e8c45fSAndroid Build Coastguard Worker }
92*38e8c45fSAndroid Build Coastguard Worker 
parseXMLInt64(const tinyxml2::XMLElement & configRoot,const char * elementName)93*38e8c45fSAndroid Build Coastguard Worker int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) {
94*38e8c45fSAndroid Build Coastguard Worker     const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
95*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
96*38e8c45fSAndroid Build Coastguard Worker 
97*38e8c45fSAndroid Build Coastguard Worker     int64_t value = 0;
98*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS,
99*38e8c45fSAndroid Build Coastguard Worker                         "Failed to parse %s: %s", elementName, element->GetText());
100*38e8c45fSAndroid Build Coastguard Worker     return value;
101*38e8c45fSAndroid Build Coastguard Worker }
102*38e8c45fSAndroid Build Coastguard Worker 
parseXMLFloat(const tinyxml2::XMLElement & configRoot,const char * elementName)103*38e8c45fSAndroid Build Coastguard Worker float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
104*38e8c45fSAndroid Build Coastguard Worker     const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
105*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
106*38e8c45fSAndroid Build Coastguard Worker 
107*38e8c45fSAndroid Build Coastguard Worker     float value = 0;
108*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS,
109*38e8c45fSAndroid Build Coastguard Worker                         "Failed to parse %s: %s", elementName, element->GetText());
110*38e8c45fSAndroid Build Coastguard Worker     return value;
111*38e8c45fSAndroid Build Coastguard Worker }
112*38e8c45fSAndroid Build Coastguard Worker 
113*38e8c45fSAndroid Build Coastguard Worker // A TFLite ErrorReporter that logs to logcat.
114*38e8c45fSAndroid Build Coastguard Worker class LoggingErrorReporter : public tflite::ErrorReporter {
115*38e8c45fSAndroid Build Coastguard Worker public:
Report(const char * format,va_list args)116*38e8c45fSAndroid Build Coastguard Worker     int Report(const char* format, va_list args) override {
117*38e8c45fSAndroid Build Coastguard Worker         return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
118*38e8c45fSAndroid Build Coastguard Worker     }
119*38e8c45fSAndroid Build Coastguard Worker };
120*38e8c45fSAndroid Build Coastguard Worker 
121*38e8c45fSAndroid Build Coastguard Worker // Searches a runner for an input tensor.
findInputTensor(const char * name,tflite::SignatureRunner * runner)122*38e8c45fSAndroid Build Coastguard Worker TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
123*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* tensor = runner->input_tensor(name);
124*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
125*38e8c45fSAndroid Build Coastguard Worker     return tensor;
126*38e8c45fSAndroid Build Coastguard Worker }
127*38e8c45fSAndroid Build Coastguard Worker 
128*38e8c45fSAndroid Build Coastguard Worker // Searches a runner for an output tensor.
findOutputTensor(const char * name,tflite::SignatureRunner * runner)129*38e8c45fSAndroid Build Coastguard Worker const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
130*38e8c45fSAndroid Build Coastguard Worker     const TfLiteTensor* tensor = runner->output_tensor(name);
131*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
132*38e8c45fSAndroid Build Coastguard Worker     return tensor;
133*38e8c45fSAndroid Build Coastguard Worker }
134*38e8c45fSAndroid Build Coastguard Worker 
135*38e8c45fSAndroid Build Coastguard Worker // Returns the buffer for a tensor of type T.
136*38e8c45fSAndroid Build Coastguard Worker template <typename T>
getTensorBuffer(typename std::conditional<std::is_const<T>::value,const TfLiteTensor *,TfLiteTensor * >::type tensor)137*38e8c45fSAndroid Build Coastguard Worker std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
138*38e8c45fSAndroid Build Coastguard Worker                                                        TfLiteTensor*>::type tensor) {
139*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!tensor);
140*38e8c45fSAndroid Build Coastguard Worker 
141*38e8c45fSAndroid Build Coastguard Worker     const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
142*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
143*38e8c45fSAndroid Build Coastguard Worker                         tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
144*38e8c45fSAndroid Build Coastguard Worker 
145*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!tensor->data.data);
146*38e8c45fSAndroid Build Coastguard Worker     return std::span<T>(reinterpret_cast<T*>(tensor->data.data), tensor->bytes / sizeof(T));
147*38e8c45fSAndroid Build Coastguard Worker }
148*38e8c45fSAndroid Build Coastguard Worker 
149*38e8c45fSAndroid Build Coastguard Worker // Verifies that a tensor exists and has an underlying buffer of type T.
150*38e8c45fSAndroid Build Coastguard Worker template <typename T>
checkTensor(const TfLiteTensor * tensor)151*38e8c45fSAndroid Build Coastguard Worker void checkTensor(const TfLiteTensor* tensor) {
152*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!tensor);
153*38e8c45fSAndroid Build Coastguard Worker 
154*38e8c45fSAndroid Build Coastguard Worker     const auto buffer = getTensorBuffer<const T>(tensor);
155*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
156*38e8c45fSAndroid Build Coastguard Worker }
157*38e8c45fSAndroid Build Coastguard Worker 
createOpResolver()158*38e8c45fSAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> createOpResolver() {
159*38e8c45fSAndroid Build Coastguard Worker     auto resolver = std::make_unique<tflite::MutableOpResolver>();
160*38e8c45fSAndroid Build Coastguard Worker     resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
161*38e8c45fSAndroid Build Coastguard Worker                          ::tflite::ops::builtin::Register_CONCATENATION());
162*38e8c45fSAndroid Build Coastguard Worker     resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
163*38e8c45fSAndroid Build Coastguard Worker                          ::tflite::ops::builtin::Register_FULLY_CONNECTED());
164*38e8c45fSAndroid Build Coastguard Worker     resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
165*38e8c45fSAndroid Build Coastguard Worker     return resolver;
166*38e8c45fSAndroid Build Coastguard Worker }
167*38e8c45fSAndroid Build Coastguard Worker 
168*38e8c45fSAndroid Build Coastguard Worker } // namespace
169*38e8c45fSAndroid Build Coastguard Worker 
TfLiteMotionPredictorBuffers(size_t inputLength)170*38e8c45fSAndroid Build Coastguard Worker TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
171*38e8c45fSAndroid Build Coastguard Worker       : mInputR(inputLength, 0),
172*38e8c45fSAndroid Build Coastguard Worker         mInputPhi(inputLength, 0),
173*38e8c45fSAndroid Build Coastguard Worker         mInputPressure(inputLength, 0),
174*38e8c45fSAndroid Build Coastguard Worker         mInputTilt(inputLength, 0),
175*38e8c45fSAndroid Build Coastguard Worker         mInputOrientation(inputLength, 0) {
176*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
177*38e8c45fSAndroid Build Coastguard Worker }
178*38e8c45fSAndroid Build Coastguard Worker 
reset()179*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorBuffers::reset() {
180*38e8c45fSAndroid Build Coastguard Worker     std::fill(mInputR.begin(), mInputR.end(), 0);
181*38e8c45fSAndroid Build Coastguard Worker     std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
182*38e8c45fSAndroid Build Coastguard Worker     std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
183*38e8c45fSAndroid Build Coastguard Worker     std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
184*38e8c45fSAndroid Build Coastguard Worker     std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
185*38e8c45fSAndroid Build Coastguard Worker     mAxisFrom.reset();
186*38e8c45fSAndroid Build Coastguard Worker     mAxisTo.reset();
187*38e8c45fSAndroid Build Coastguard Worker }
188*38e8c45fSAndroid Build Coastguard Worker 
copyTo(TfLiteMotionPredictorModel & model) const189*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
190*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
191*38e8c45fSAndroid Build Coastguard Worker                         "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
192*38e8c45fSAndroid Build Coastguard Worker                         model.inputLength());
193*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
194*38e8c45fSAndroid Build Coastguard Worker 
195*38e8c45fSAndroid Build Coastguard Worker     std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
196*38e8c45fSAndroid Build Coastguard Worker     std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
197*38e8c45fSAndroid Build Coastguard Worker     std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
198*38e8c45fSAndroid Build Coastguard Worker     std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
199*38e8c45fSAndroid Build Coastguard Worker     std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
200*38e8c45fSAndroid Build Coastguard Worker }
201*38e8c45fSAndroid Build Coastguard Worker 
pushSample(int64_t timestamp,const TfLiteMotionPredictorSample sample)202*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
203*38e8c45fSAndroid Build Coastguard Worker                                               const TfLiteMotionPredictorSample sample) {
204*38e8c45fSAndroid Build Coastguard Worker     // Convert the sample (x, y) into polar (r, φ) based on a reference axis
205*38e8c45fSAndroid Build Coastguard Worker     // from the preceding two points (mAxisFrom/mAxisTo).
206*38e8c45fSAndroid Build Coastguard Worker 
207*38e8c45fSAndroid Build Coastguard Worker     mTimestamp = timestamp;
208*38e8c45fSAndroid Build Coastguard Worker 
209*38e8c45fSAndroid Build Coastguard Worker     if (!mAxisTo) { // First point.
210*38e8c45fSAndroid Build Coastguard Worker         mAxisTo = sample;
211*38e8c45fSAndroid Build Coastguard Worker         return;
212*38e8c45fSAndroid Build Coastguard Worker     }
213*38e8c45fSAndroid Build Coastguard Worker 
214*38e8c45fSAndroid Build Coastguard Worker     // Vector from the last point to the current sample point.
215*38e8c45fSAndroid Build Coastguard Worker     const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
216*38e8c45fSAndroid Build Coastguard Worker 
217*38e8c45fSAndroid Build Coastguard Worker     const float r = std::hypot(v.x, v.y);
218*38e8c45fSAndroid Build Coastguard Worker     float phi = 0;
219*38e8c45fSAndroid Build Coastguard Worker     float orientation = 0;
220*38e8c45fSAndroid Build Coastguard Worker 
221*38e8c45fSAndroid Build Coastguard Worker     if (!mAxisFrom && r > 0) { // Second point.
222*38e8c45fSAndroid Build Coastguard Worker         // We can only determine the distance from the first point, and not any
223*38e8c45fSAndroid Build Coastguard Worker         // angle. However, if the second point forms an axis, the orientation can
224*38e8c45fSAndroid Build Coastguard Worker         // be transformed relative to that axis.
225*38e8c45fSAndroid Build Coastguard Worker         const float axisPhi = std::atan2(v.y, v.x);
226*38e8c45fSAndroid Build Coastguard Worker         // A MotionEvent's orientation is measured clockwise from the vertical
227*38e8c45fSAndroid Build Coastguard Worker         // axis, but axisPhi is measured counter-clockwise from the horizontal
228*38e8c45fSAndroid Build Coastguard Worker         // axis.
229*38e8c45fSAndroid Build Coastguard Worker         orientation = M_PI_2 - sample.orientation - axisPhi;
230*38e8c45fSAndroid Build Coastguard Worker     } else {
231*38e8c45fSAndroid Build Coastguard Worker         const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
232*38e8c45fSAndroid Build Coastguard Worker         const float axisPhi = std::atan2(axis.y, axis.x);
233*38e8c45fSAndroid Build Coastguard Worker         phi = std::atan2(v.y, v.x) - axisPhi;
234*38e8c45fSAndroid Build Coastguard Worker 
235*38e8c45fSAndroid Build Coastguard Worker         if (std::hypot(axis.x, axis.y) > 0) {
236*38e8c45fSAndroid Build Coastguard Worker             // See note above.
237*38e8c45fSAndroid Build Coastguard Worker             orientation = M_PI_2 - sample.orientation - axisPhi;
238*38e8c45fSAndroid Build Coastguard Worker         }
239*38e8c45fSAndroid Build Coastguard Worker     }
240*38e8c45fSAndroid Build Coastguard Worker 
241*38e8c45fSAndroid Build Coastguard Worker     // Update the axis for the next point.
242*38e8c45fSAndroid Build Coastguard Worker     if (r > 0) {
243*38e8c45fSAndroid Build Coastguard Worker         mAxisFrom = mAxisTo;
244*38e8c45fSAndroid Build Coastguard Worker         mAxisTo = sample;
245*38e8c45fSAndroid Build Coastguard Worker     }
246*38e8c45fSAndroid Build Coastguard Worker 
247*38e8c45fSAndroid Build Coastguard Worker     // Push the current sample onto the end of the input buffers.
248*38e8c45fSAndroid Build Coastguard Worker     mInputR.pushBack(r);
249*38e8c45fSAndroid Build Coastguard Worker     mInputPhi.pushBack(phi);
250*38e8c45fSAndroid Build Coastguard Worker     mInputPressure.pushBack(sample.pressure);
251*38e8c45fSAndroid Build Coastguard Worker     mInputTilt.pushBack(sample.tilt);
252*38e8c45fSAndroid Build Coastguard Worker     mInputOrientation.pushBack(orientation);
253*38e8c45fSAndroid Build Coastguard Worker }
254*38e8c45fSAndroid Build Coastguard Worker 
create()255*38e8c45fSAndroid Build Coastguard Worker std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
256*38e8c45fSAndroid Build Coastguard Worker     const std::string modelPath = getModelPath();
257*38e8c45fSAndroid Build Coastguard Worker     android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY));
258*38e8c45fSAndroid Build Coastguard Worker     if (fd == -1) {
259*38e8c45fSAndroid Build Coastguard Worker         PLOG(FATAL) << "Could not read model from " << modelPath;
260*38e8c45fSAndroid Build Coastguard Worker     }
261*38e8c45fSAndroid Build Coastguard Worker 
262*38e8c45fSAndroid Build Coastguard Worker     const off_t fdSize = lseek(fd, 0, SEEK_END);
263*38e8c45fSAndroid Build Coastguard Worker     if (fdSize == -1) {
264*38e8c45fSAndroid Build Coastguard Worker         PLOG(FATAL) << "Failed to determine file size";
265*38e8c45fSAndroid Build Coastguard Worker     }
266*38e8c45fSAndroid Build Coastguard Worker 
267*38e8c45fSAndroid Build Coastguard Worker     std::unique_ptr<android::base::MappedFile> modelBuffer =
268*38e8c45fSAndroid Build Coastguard Worker             android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
269*38e8c45fSAndroid Build Coastguard Worker     if (!modelBuffer) {
270*38e8c45fSAndroid Build Coastguard Worker         PLOG(FATAL) << "Failed to mmap model";
271*38e8c45fSAndroid Build Coastguard Worker     }
272*38e8c45fSAndroid Build Coastguard Worker 
273*38e8c45fSAndroid Build Coastguard Worker     const std::string configPath = getConfigPath();
274*38e8c45fSAndroid Build Coastguard Worker     tinyxml2::XMLDocument configDocument;
275*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS,
276*38e8c45fSAndroid Build Coastguard Worker                         "Failed to load config file from %s", configPath.c_str());
277*38e8c45fSAndroid Build Coastguard Worker 
278*38e8c45fSAndroid Build Coastguard Worker     // Parse configuration file.
279*38e8c45fSAndroid Build Coastguard Worker     const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
280*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!configRoot);
281*38e8c45fSAndroid Build Coastguard Worker     Config config{
282*38e8c45fSAndroid Build Coastguard Worker             .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
283*38e8c45fSAndroid Build Coastguard Worker             .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
284*38e8c45fSAndroid Build Coastguard Worker             .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
285*38e8c45fSAndroid Build Coastguard Worker             .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
286*38e8c45fSAndroid Build Coastguard Worker             .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"),
287*38e8c45fSAndroid Build Coastguard Worker     };
288*38e8c45fSAndroid Build Coastguard Worker 
289*38e8c45fSAndroid Build Coastguard Worker     return std::unique_ptr<TfLiteMotionPredictorModel>(
290*38e8c45fSAndroid Build Coastguard Worker             new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
291*38e8c45fSAndroid Build Coastguard Worker }
292*38e8c45fSAndroid Build Coastguard Worker 
TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,Config config)293*38e8c45fSAndroid Build Coastguard Worker TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
294*38e8c45fSAndroid Build Coastguard Worker         std::unique_ptr<android::base::MappedFile> model, Config config)
295*38e8c45fSAndroid Build Coastguard Worker       : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
296*38e8c45fSAndroid Build Coastguard Worker     CHECK(mFlatBuffer);
297*38e8c45fSAndroid Build Coastguard Worker     mErrorReporter = std::make_unique<LoggingErrorReporter>();
298*38e8c45fSAndroid Build Coastguard Worker     mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
299*38e8c45fSAndroid Build Coastguard Worker                                                                mFlatBuffer->size(),
300*38e8c45fSAndroid Build Coastguard Worker                                                                /*extra_verifier=*/nullptr,
301*38e8c45fSAndroid Build Coastguard Worker                                                                mErrorReporter.get());
302*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!mModel);
303*38e8c45fSAndroid Build Coastguard Worker 
304*38e8c45fSAndroid Build Coastguard Worker     auto resolver = createOpResolver();
305*38e8c45fSAndroid Build Coastguard Worker     tflite::InterpreterBuilder builder(*mModel, *resolver);
306*38e8c45fSAndroid Build Coastguard Worker 
307*38e8c45fSAndroid Build Coastguard Worker     if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
308*38e8c45fSAndroid Build Coastguard Worker         LOG_ALWAYS_FATAL("Failed to build interpreter");
309*38e8c45fSAndroid Build Coastguard Worker     }
310*38e8c45fSAndroid Build Coastguard Worker 
311*38e8c45fSAndroid Build Coastguard Worker     mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
312*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
313*38e8c45fSAndroid Build Coastguard Worker 
314*38e8c45fSAndroid Build Coastguard Worker     allocateTensors();
315*38e8c45fSAndroid Build Coastguard Worker }
316*38e8c45fSAndroid Build Coastguard Worker 
~TfLiteMotionPredictorModel()317*38e8c45fSAndroid Build Coastguard Worker TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
318*38e8c45fSAndroid Build Coastguard Worker 
allocateTensors()319*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorModel::allocateTensors() {
320*38e8c45fSAndroid Build Coastguard Worker     if (mRunner->AllocateTensors() != kTfLiteOk) {
321*38e8c45fSAndroid Build Coastguard Worker         LOG_ALWAYS_FATAL("Failed to allocate tensors");
322*38e8c45fSAndroid Build Coastguard Worker     }
323*38e8c45fSAndroid Build Coastguard Worker 
324*38e8c45fSAndroid Build Coastguard Worker     attachInputTensors();
325*38e8c45fSAndroid Build Coastguard Worker     attachOutputTensors();
326*38e8c45fSAndroid Build Coastguard Worker 
327*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mInputR);
328*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mInputPhi);
329*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mInputPressure);
330*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mInputTilt);
331*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mInputOrientation);
332*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mOutputR);
333*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mOutputPhi);
334*38e8c45fSAndroid Build Coastguard Worker     checkTensor<float>(mOutputPressure);
335*38e8c45fSAndroid Build Coastguard Worker 
336*38e8c45fSAndroid Build Coastguard Worker     const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
337*38e8c45fSAndroid Build Coastguard Worker         const size_t size = getTensorBuffer<const float>(tensor).size();
338*38e8c45fSAndroid Build Coastguard Worker         LOG_ALWAYS_FATAL_IF(size != inputLength(),
339*38e8c45fSAndroid Build Coastguard Worker                             "Tensor '%s' length %zu does not match input length %zu", tensor->name,
340*38e8c45fSAndroid Build Coastguard Worker                             size, inputLength());
341*38e8c45fSAndroid Build Coastguard Worker     };
342*38e8c45fSAndroid Build Coastguard Worker 
343*38e8c45fSAndroid Build Coastguard Worker     checkInputTensorSize(mInputR);
344*38e8c45fSAndroid Build Coastguard Worker     checkInputTensorSize(mInputPhi);
345*38e8c45fSAndroid Build Coastguard Worker     checkInputTensorSize(mInputPressure);
346*38e8c45fSAndroid Build Coastguard Worker     checkInputTensorSize(mInputTilt);
347*38e8c45fSAndroid Build Coastguard Worker     checkInputTensorSize(mInputOrientation);
348*38e8c45fSAndroid Build Coastguard Worker }
349*38e8c45fSAndroid Build Coastguard Worker 
attachInputTensors()350*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorModel::attachInputTensors() {
351*38e8c45fSAndroid Build Coastguard Worker     mInputR = findInputTensor(INPUT_R, mRunner);
352*38e8c45fSAndroid Build Coastguard Worker     mInputPhi = findInputTensor(INPUT_PHI, mRunner);
353*38e8c45fSAndroid Build Coastguard Worker     mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
354*38e8c45fSAndroid Build Coastguard Worker     mInputTilt = findInputTensor(INPUT_TILT, mRunner);
355*38e8c45fSAndroid Build Coastguard Worker     mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
356*38e8c45fSAndroid Build Coastguard Worker }
357*38e8c45fSAndroid Build Coastguard Worker 
attachOutputTensors()358*38e8c45fSAndroid Build Coastguard Worker void TfLiteMotionPredictorModel::attachOutputTensors() {
359*38e8c45fSAndroid Build Coastguard Worker     mOutputR = findOutputTensor(OUTPUT_R, mRunner);
360*38e8c45fSAndroid Build Coastguard Worker     mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
361*38e8c45fSAndroid Build Coastguard Worker     mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
362*38e8c45fSAndroid Build Coastguard Worker }
363*38e8c45fSAndroid Build Coastguard Worker 
invoke()364*38e8c45fSAndroid Build Coastguard Worker bool TfLiteMotionPredictorModel::invoke() {
365*38e8c45fSAndroid Build Coastguard Worker     ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
366*38e8c45fSAndroid Build Coastguard Worker     TfLiteStatus result = mRunner->Invoke();
367*38e8c45fSAndroid Build Coastguard Worker     ATRACE_END();
368*38e8c45fSAndroid Build Coastguard Worker 
369*38e8c45fSAndroid Build Coastguard Worker     if (result != kTfLiteOk) {
370*38e8c45fSAndroid Build Coastguard Worker         return false;
371*38e8c45fSAndroid Build Coastguard Worker     }
372*38e8c45fSAndroid Build Coastguard Worker 
373*38e8c45fSAndroid Build Coastguard Worker     // Invoke() might reallocate tensors, so they need to be reattached.
374*38e8c45fSAndroid Build Coastguard Worker     attachInputTensors();
375*38e8c45fSAndroid Build Coastguard Worker     attachOutputTensors();
376*38e8c45fSAndroid Build Coastguard Worker 
377*38e8c45fSAndroid Build Coastguard Worker     if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
378*38e8c45fSAndroid Build Coastguard Worker         LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
379*38e8c45fSAndroid Build Coastguard Worker                          outputR().size(), outputPhi().size(), outputPressure().size());
380*38e8c45fSAndroid Build Coastguard Worker     }
381*38e8c45fSAndroid Build Coastguard Worker 
382*38e8c45fSAndroid Build Coastguard Worker     return true;
383*38e8c45fSAndroid Build Coastguard Worker }
384*38e8c45fSAndroid Build Coastguard Worker 
inputLength() const385*38e8c45fSAndroid Build Coastguard Worker size_t TfLiteMotionPredictorModel::inputLength() const {
386*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<const float>(mInputR).size();
387*38e8c45fSAndroid Build Coastguard Worker }
388*38e8c45fSAndroid Build Coastguard Worker 
outputLength() const389*38e8c45fSAndroid Build Coastguard Worker size_t TfLiteMotionPredictorModel::outputLength() const {
390*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<const float>(mOutputR).size();
391*38e8c45fSAndroid Build Coastguard Worker }
392*38e8c45fSAndroid Build Coastguard Worker 
inputR()393*38e8c45fSAndroid Build Coastguard Worker std::span<float> TfLiteMotionPredictorModel::inputR() {
394*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<float>(mInputR);
395*38e8c45fSAndroid Build Coastguard Worker }
396*38e8c45fSAndroid Build Coastguard Worker 
inputPhi()397*38e8c45fSAndroid Build Coastguard Worker std::span<float> TfLiteMotionPredictorModel::inputPhi() {
398*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<float>(mInputPhi);
399*38e8c45fSAndroid Build Coastguard Worker }
400*38e8c45fSAndroid Build Coastguard Worker 
inputPressure()401*38e8c45fSAndroid Build Coastguard Worker std::span<float> TfLiteMotionPredictorModel::inputPressure() {
402*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<float>(mInputPressure);
403*38e8c45fSAndroid Build Coastguard Worker }
404*38e8c45fSAndroid Build Coastguard Worker 
inputTilt()405*38e8c45fSAndroid Build Coastguard Worker std::span<float> TfLiteMotionPredictorModel::inputTilt() {
406*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<float>(mInputTilt);
407*38e8c45fSAndroid Build Coastguard Worker }
408*38e8c45fSAndroid Build Coastguard Worker 
inputOrientation()409*38e8c45fSAndroid Build Coastguard Worker std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
410*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<float>(mInputOrientation);
411*38e8c45fSAndroid Build Coastguard Worker }
412*38e8c45fSAndroid Build Coastguard Worker 
outputR() const413*38e8c45fSAndroid Build Coastguard Worker std::span<const float> TfLiteMotionPredictorModel::outputR() const {
414*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<const float>(mOutputR);
415*38e8c45fSAndroid Build Coastguard Worker }
416*38e8c45fSAndroid Build Coastguard Worker 
outputPhi() const417*38e8c45fSAndroid Build Coastguard Worker std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
418*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<const float>(mOutputPhi);
419*38e8c45fSAndroid Build Coastguard Worker }
420*38e8c45fSAndroid Build Coastguard Worker 
outputPressure() const421*38e8c45fSAndroid Build Coastguard Worker std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
422*38e8c45fSAndroid Build Coastguard Worker     return getTensorBuffer<const float>(mOutputPressure);
423*38e8c45fSAndroid Build Coastguard Worker }
424*38e8c45fSAndroid Build Coastguard Worker 
425*38e8c45fSAndroid Build Coastguard Worker } // namespace android
426