xref: /aosp_15_r20/frameworks/av/media/libheadtracking/PosePredictor.cpp (revision ec779b8e0859a360c3d303172224686826e6e0e1)
1*ec779b8eSAndroid Build Coastguard Worker /*
2*ec779b8eSAndroid Build Coastguard Worker  * Copyright (C) 2023 The Android Open Source Project
3*ec779b8eSAndroid Build Coastguard Worker  *
4*ec779b8eSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*ec779b8eSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*ec779b8eSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*ec779b8eSAndroid Build Coastguard Worker  *
8*ec779b8eSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*ec779b8eSAndroid Build Coastguard Worker  *
10*ec779b8eSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*ec779b8eSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*ec779b8eSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*ec779b8eSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*ec779b8eSAndroid Build Coastguard Worker  * limitations under the License.
15*ec779b8eSAndroid Build Coastguard Worker  */
16*ec779b8eSAndroid Build Coastguard Worker 
17*ec779b8eSAndroid Build Coastguard Worker #include "PosePredictor.h"
18*ec779b8eSAndroid Build Coastguard Worker 
19*ec779b8eSAndroid Build Coastguard Worker namespace android::media {
20*ec779b8eSAndroid Build Coastguard Worker 
namespace {

// Verification mode feeds every predictor each sample and scores each one at the
// lookahead times below (in milliseconds). It is compiled in only when
// ENABLE_VERIFICATION is defined; otherwise no lookaheads are configured.
#if defined(ENABLE_VERIFICATION)
constexpr bool kEnableVerification = true;
constexpr std::array<int, 3> kLookAheadMs{50, 100, 200};
#else
constexpr bool kEnableVerification = false;
constexpr std::array<int, 0> kLookAheadMs{};
#endif

}  // namespace
31*ec779b8eSAndroid Build Coastguard Worker 
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)32*ec779b8eSAndroid Build Coastguard Worker void LeastSquaresPredictor::add(int64_t atNs, const Pose3f& pose, const Twist3f& twist)
33*ec779b8eSAndroid Build Coastguard Worker {
34*ec779b8eSAndroid Build Coastguard Worker     (void)twist;
35*ec779b8eSAndroid Build Coastguard Worker     mLastAtNs = atNs;
36*ec779b8eSAndroid Build Coastguard Worker     mLastPose = pose;
37*ec779b8eSAndroid Build Coastguard Worker     const auto q = pose.rotation();
38*ec779b8eSAndroid Build Coastguard Worker     const double datNs = static_cast<double>(atNs);
39*ec779b8eSAndroid Build Coastguard Worker     mRw.add({datNs, q.w()});
40*ec779b8eSAndroid Build Coastguard Worker     mRx.add({datNs, q.x()});
41*ec779b8eSAndroid Build Coastguard Worker     mRy.add({datNs, q.y()});
42*ec779b8eSAndroid Build Coastguard Worker     mRz.add({datNs, q.z()});
43*ec779b8eSAndroid Build Coastguard Worker }
44*ec779b8eSAndroid Build Coastguard Worker 
predict(int64_t atNs) const45*ec779b8eSAndroid Build Coastguard Worker Pose3f LeastSquaresPredictor::predict(int64_t atNs) const
46*ec779b8eSAndroid Build Coastguard Worker {
47*ec779b8eSAndroid Build Coastguard Worker     if (mRw.getN() < kMinimumSamplesForPrediction) return mLastPose;
48*ec779b8eSAndroid Build Coastguard Worker 
49*ec779b8eSAndroid Build Coastguard Worker     /*
50*ec779b8eSAndroid Build Coastguard Worker      * Using parametric form, we have q(t) = { w(t), x(t), y(t), z(t) }.
51*ec779b8eSAndroid Build Coastguard Worker      * We compute the least squares prediction of w, x, y, z.
52*ec779b8eSAndroid Build Coastguard Worker      */
53*ec779b8eSAndroid Build Coastguard Worker     const double dLookahead = static_cast<double>(atNs);
54*ec779b8eSAndroid Build Coastguard Worker     Eigen::Quaternionf lsq(
55*ec779b8eSAndroid Build Coastguard Worker         mRw.getYFromX(dLookahead),
56*ec779b8eSAndroid Build Coastguard Worker         mRx.getYFromX(dLookahead),
57*ec779b8eSAndroid Build Coastguard Worker         mRy.getYFromX(dLookahead),
58*ec779b8eSAndroid Build Coastguard Worker         mRz.getYFromX(dLookahead));
59*ec779b8eSAndroid Build Coastguard Worker 
60*ec779b8eSAndroid Build Coastguard Worker      /*
61*ec779b8eSAndroid Build Coastguard Worker       * We cheat here, since the result lsq is the least squares prediction
62*ec779b8eSAndroid Build Coastguard Worker       * in H (arbitrary quaternion), not the least squares prediction in
63*ec779b8eSAndroid Build Coastguard Worker       * SO(3) (unit quaternion).
64*ec779b8eSAndroid Build Coastguard Worker       *
65*ec779b8eSAndroid Build Coastguard Worker       * In other words, the result for lsq is most likely not a unit quaternion.
66*ec779b8eSAndroid Build Coastguard Worker       * To solve this, we normalize, thereby selecting the closest unit quaternion
67*ec779b8eSAndroid Build Coastguard Worker       * in SO(3) to the prediction in H.
68*ec779b8eSAndroid Build Coastguard Worker       */
69*ec779b8eSAndroid Build Coastguard Worker     lsq.normalize();
70*ec779b8eSAndroid Build Coastguard Worker     return Pose3f(lsq);
71*ec779b8eSAndroid Build Coastguard Worker }
72*ec779b8eSAndroid Build Coastguard Worker 
reset()73*ec779b8eSAndroid Build Coastguard Worker void LeastSquaresPredictor::reset() {
74*ec779b8eSAndroid Build Coastguard Worker     mLastAtNs = {};
75*ec779b8eSAndroid Build Coastguard Worker     mLastPose = {};
76*ec779b8eSAndroid Build Coastguard Worker     mRw.reset();
77*ec779b8eSAndroid Build Coastguard Worker     mRx.reset();
78*ec779b8eSAndroid Build Coastguard Worker     mRy.reset();
79*ec779b8eSAndroid Build Coastguard Worker     mRz.reset();
80*ec779b8eSAndroid Build Coastguard Worker }
81*ec779b8eSAndroid Build Coastguard Worker 
toString(size_t index) const82*ec779b8eSAndroid Build Coastguard Worker std::string LeastSquaresPredictor::toString(size_t index) const {
83*ec779b8eSAndroid Build Coastguard Worker     std::string s(index, ' ');
84*ec779b8eSAndroid Build Coastguard Worker     s.append("LeastSquaresPredictor using alpha: ")
85*ec779b8eSAndroid Build Coastguard Worker         .append(std::to_string(mAlpha))
86*ec779b8eSAndroid Build Coastguard Worker         .append(" last pose: ")
87*ec779b8eSAndroid Build Coastguard Worker         .append(mLastPose.toString())
88*ec779b8eSAndroid Build Coastguard Worker         .append("\n");
89*ec779b8eSAndroid Build Coastguard Worker     return s;
90*ec779b8eSAndroid Build Coastguard Worker }
91*ec779b8eSAndroid Build Coastguard Worker 
92*ec779b8eSAndroid Build Coastguard Worker // Formatting
/**
 * Computes where to draw a delimiter between lookahead groups when error vectors are laid
 * out as `lookaheads` consecutive groups of `predictors` entries each.
 *
 * @param predictors  number of entries in each group.
 * @param lookaheads  number of groups.
 * @return the starting index of every group except the first, i.e. {predictors,
 *         2 * predictors, ...}; empty when there are fewer than two groups.
 */
static inline std::vector<size_t> createDelimiterIdx(size_t predictors, size_t lookaheads) {
    std::vector<size_t> boundaries;
    boundaries.reserve(lookaheads > 0 ? lookaheads - 1 : 0);
    for (size_t group = 1; group < lookaheads; ++group) {
        boundaries.push_back(group * predictors);
    }
    return boundaries;
}
102*ec779b8eSAndroid Build Coastguard Worker 
PosePredictor()103*ec779b8eSAndroid Build Coastguard Worker PosePredictor::PosePredictor()
104*ec779b8eSAndroid Build Coastguard Worker     : mPredictors{
105*ec779b8eSAndroid Build Coastguard Worker             // First predictors must match switch in getCurrentPredictor()
106*ec779b8eSAndroid Build Coastguard Worker             std::make_shared<LastPredictor>(),
107*ec779b8eSAndroid Build Coastguard Worker             std::make_shared<TwistPredictor>(),
108*ec779b8eSAndroid Build Coastguard Worker             std::make_shared<LeastSquaresPredictor>(),
109*ec779b8eSAndroid Build Coastguard Worker             // After this, can place additional predictors here for comparison such as
110*ec779b8eSAndroid Build Coastguard Worker             // std::make_shared<LeastSquaresPredictor>(0.25),
111*ec779b8eSAndroid Build Coastguard Worker         }
112*ec779b8eSAndroid Build Coastguard Worker     , mLookaheadMs(kLookAheadMs.begin(), kLookAheadMs.end())
113*ec779b8eSAndroid Build Coastguard Worker     , mVerifiers(std::size(mLookaheadMs) * std::size(mPredictors))
114*ec779b8eSAndroid Build Coastguard Worker     , mDelimiterIdx(createDelimiterIdx(std::size(mPredictors), std::size(mLookaheadMs)))
115*ec779b8eSAndroid Build Coastguard Worker     , mPredictionRecorder(
116*ec779b8eSAndroid Build Coastguard Worker         std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
117*ec779b8eSAndroid Build Coastguard Worker         mDelimiterIdx)
118*ec779b8eSAndroid Build Coastguard Worker     , mPredictionDurableRecorder(
119*ec779b8eSAndroid Build Coastguard Worker         std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
120*ec779b8eSAndroid Build Coastguard Worker         mDelimiterIdx)
121*ec779b8eSAndroid Build Coastguard Worker     {
122*ec779b8eSAndroid Build Coastguard Worker }
123*ec779b8eSAndroid Build Coastguard Worker 
predict(int64_t timestampNs,const Pose3f & pose,const Twist3f & twist,float predictionDurationNs)124*ec779b8eSAndroid Build Coastguard Worker Pose3f PosePredictor::predict(
125*ec779b8eSAndroid Build Coastguard Worker         int64_t timestampNs, const Pose3f& pose, const Twist3f& twist, float predictionDurationNs)
126*ec779b8eSAndroid Build Coastguard Worker {
127*ec779b8eSAndroid Build Coastguard Worker     if (timestampNs - mLastTimestampNs > kMaximumSampleIntervalBeforeResetNs) {
128*ec779b8eSAndroid Build Coastguard Worker         for (const auto& predictor : mPredictors) {
129*ec779b8eSAndroid Build Coastguard Worker             predictor->reset();
130*ec779b8eSAndroid Build Coastguard Worker         }
131*ec779b8eSAndroid Build Coastguard Worker         ++mResets;
132*ec779b8eSAndroid Build Coastguard Worker     }
133*ec779b8eSAndroid Build Coastguard Worker     mLastTimestampNs = timestampNs;
134*ec779b8eSAndroid Build Coastguard Worker 
135*ec779b8eSAndroid Build Coastguard Worker     auto selectedPredictor = getCurrentPredictor();
136*ec779b8eSAndroid Build Coastguard Worker     if constexpr (kEnableVerification) {
137*ec779b8eSAndroid Build Coastguard Worker         // Update all Predictors
138*ec779b8eSAndroid Build Coastguard Worker         for (const auto& predictor : mPredictors) {
139*ec779b8eSAndroid Build Coastguard Worker             predictor->add(timestampNs, pose, twist);
140*ec779b8eSAndroid Build Coastguard Worker         }
141*ec779b8eSAndroid Build Coastguard Worker 
142*ec779b8eSAndroid Build Coastguard Worker         // Update Verifiers and calculate errors
143*ec779b8eSAndroid Build Coastguard Worker         std::vector<float> error(std::size(mVerifiers));
144*ec779b8eSAndroid Build Coastguard Worker         for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
145*ec779b8eSAndroid Build Coastguard Worker             constexpr float RADIAN_TO_DEGREES = 180 / M_PI;
146*ec779b8eSAndroid Build Coastguard Worker             const int64_t atNs =
147*ec779b8eSAndroid Build Coastguard Worker                     timestampNs + mLookaheadMs[i] * PosePredictorVerifier::kMillisToNanos;
148*ec779b8eSAndroid Build Coastguard Worker 
149*ec779b8eSAndroid Build Coastguard Worker             for (size_t j = 0; j < mPredictors.size(); ++j) {
150*ec779b8eSAndroid Build Coastguard Worker                 const size_t idx = i * std::size(mPredictors) + j;
151*ec779b8eSAndroid Build Coastguard Worker                 mVerifiers[idx].verifyActualPose(timestampNs, pose);
152*ec779b8eSAndroid Build Coastguard Worker                 mVerifiers[idx].addPredictedPose(atNs, mPredictors[j]->predict(atNs));
153*ec779b8eSAndroid Build Coastguard Worker                 error[idx] =  RADIAN_TO_DEGREES * mVerifiers[idx].lastError();
154*ec779b8eSAndroid Build Coastguard Worker             }
155*ec779b8eSAndroid Build Coastguard Worker         }
156*ec779b8eSAndroid Build Coastguard Worker         // Record errors
157*ec779b8eSAndroid Build Coastguard Worker         mPredictionRecorder.record(error);
158*ec779b8eSAndroid Build Coastguard Worker         mPredictionDurableRecorder.record(error);
159*ec779b8eSAndroid Build Coastguard Worker     } else /* constexpr */ {
160*ec779b8eSAndroid Build Coastguard Worker         selectedPredictor->add(timestampNs, pose, twist);
161*ec779b8eSAndroid Build Coastguard Worker     }
162*ec779b8eSAndroid Build Coastguard Worker 
163*ec779b8eSAndroid Build Coastguard Worker     // Deliver prediction
164*ec779b8eSAndroid Build Coastguard Worker     const int64_t predictionTimeNs = timestampNs + (int64_t)predictionDurationNs;
165*ec779b8eSAndroid Build Coastguard Worker     return selectedPredictor->predict(predictionTimeNs);
166*ec779b8eSAndroid Build Coastguard Worker }
167*ec779b8eSAndroid Build Coastguard Worker 
setPosePredictorType(PosePredictorType type)168*ec779b8eSAndroid Build Coastguard Worker void PosePredictor::setPosePredictorType(PosePredictorType type) {
169*ec779b8eSAndroid Build Coastguard Worker     if (!isValidPosePredictorType(type)) return;
170*ec779b8eSAndroid Build Coastguard Worker     if (type == mSetType) return;
171*ec779b8eSAndroid Build Coastguard Worker     mSetType = type;
172*ec779b8eSAndroid Build Coastguard Worker     if (type == android::media::PosePredictorType::AUTO) {
173*ec779b8eSAndroid Build Coastguard Worker         type = android::media::PosePredictorType::LEAST_SQUARES;
174*ec779b8eSAndroid Build Coastguard Worker     }
175*ec779b8eSAndroid Build Coastguard Worker     if (type != mCurrentType) {
176*ec779b8eSAndroid Build Coastguard Worker         mCurrentType = type;
177*ec779b8eSAndroid Build Coastguard Worker         if constexpr (!kEnableVerification) {
178*ec779b8eSAndroid Build Coastguard Worker             // Verification keeps all predictors up-to-date.
179*ec779b8eSAndroid Build Coastguard Worker             // If we don't enable verification, we must reset the current predictor.
180*ec779b8eSAndroid Build Coastguard Worker             getCurrentPredictor()->reset();
181*ec779b8eSAndroid Build Coastguard Worker         }
182*ec779b8eSAndroid Build Coastguard Worker     }
183*ec779b8eSAndroid Build Coastguard Worker }
184*ec779b8eSAndroid Build Coastguard Worker 
toString(size_t index) const185*ec779b8eSAndroid Build Coastguard Worker std::string PosePredictor::toString(size_t index) const {
186*ec779b8eSAndroid Build Coastguard Worker     std::string prefixSpace(index, ' ');
187*ec779b8eSAndroid Build Coastguard Worker     std::string ss(prefixSpace);
188*ec779b8eSAndroid Build Coastguard Worker     ss.append("PosePredictor:\n")
189*ec779b8eSAndroid Build Coastguard Worker         .append(prefixSpace)
190*ec779b8eSAndroid Build Coastguard Worker         .append(" Current Prediction Type: ")
191*ec779b8eSAndroid Build Coastguard Worker         .append(android::media::toString(mCurrentType))
192*ec779b8eSAndroid Build Coastguard Worker         .append("\n")
193*ec779b8eSAndroid Build Coastguard Worker         .append(prefixSpace)
194*ec779b8eSAndroid Build Coastguard Worker         .append(" Resets: ")
195*ec779b8eSAndroid Build Coastguard Worker         .append(std::to_string(mResets))
196*ec779b8eSAndroid Build Coastguard Worker         .append("\n")
197*ec779b8eSAndroid Build Coastguard Worker         .append(getCurrentPredictor()->toString(index + 1));
198*ec779b8eSAndroid Build Coastguard Worker     if constexpr (kEnableVerification) {
199*ec779b8eSAndroid Build Coastguard Worker         // dump verification
200*ec779b8eSAndroid Build Coastguard Worker         ss.append(prefixSpace)
201*ec779b8eSAndroid Build Coastguard Worker             .append(" Prediction abs error (L1) degrees [ type (");
202*ec779b8eSAndroid Build Coastguard Worker         for (size_t i = 0; i < mPredictors.size(); ++i) {
203*ec779b8eSAndroid Build Coastguard Worker             if (i > 0) ss.append(" , ");
204*ec779b8eSAndroid Build Coastguard Worker             ss.append(mPredictors[i]->name());
205*ec779b8eSAndroid Build Coastguard Worker         }
206*ec779b8eSAndroid Build Coastguard Worker         ss.append(" ) x ( ");
207*ec779b8eSAndroid Build Coastguard Worker         for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
208*ec779b8eSAndroid Build Coastguard Worker             if (i > 0) ss.append(" : ");
209*ec779b8eSAndroid Build Coastguard Worker             ss.append(std::to_string(mLookaheadMs[i]));
210*ec779b8eSAndroid Build Coastguard Worker         }
211*ec779b8eSAndroid Build Coastguard Worker         std::vector<float> cumulativeAverageErrors(std::size(mVerifiers));
212*ec779b8eSAndroid Build Coastguard Worker         for (size_t i = 0; i < cumulativeAverageErrors.size(); ++i) {
213*ec779b8eSAndroid Build Coastguard Worker             cumulativeAverageErrors[i] = mVerifiers[i].cumulativeAverageError();
214*ec779b8eSAndroid Build Coastguard Worker         }
215*ec779b8eSAndroid Build Coastguard Worker         ss.append(" ) ms ]\n")
216*ec779b8eSAndroid Build Coastguard Worker             .append(prefixSpace)
217*ec779b8eSAndroid Build Coastguard Worker             .append("  Cumulative Average Error:\n")
218*ec779b8eSAndroid Build Coastguard Worker             .append(prefixSpace)
219*ec779b8eSAndroid Build Coastguard Worker             .append("   ")
220*ec779b8eSAndroid Build Coastguard Worker             .append(VectorRecorder::toString(cumulativeAverageErrors, mDelimiterIdx, "%.3g"))
221*ec779b8eSAndroid Build Coastguard Worker             .append("\n")
222*ec779b8eSAndroid Build Coastguard Worker             .append(prefixSpace)
223*ec779b8eSAndroid Build Coastguard Worker             .append("  PerMinuteHistory:\n")
224*ec779b8eSAndroid Build Coastguard Worker             .append(mPredictionDurableRecorder.toString(index + 3))
225*ec779b8eSAndroid Build Coastguard Worker             .append(prefixSpace)
226*ec779b8eSAndroid Build Coastguard Worker             .append("  PerSecondHistory:\n")
227*ec779b8eSAndroid Build Coastguard Worker             .append(mPredictionRecorder.toString(index + 3));
228*ec779b8eSAndroid Build Coastguard Worker     }
229*ec779b8eSAndroid Build Coastguard Worker     return ss;
230*ec779b8eSAndroid Build Coastguard Worker }
231*ec779b8eSAndroid Build Coastguard Worker 
getCurrentPredictor() const232*ec779b8eSAndroid Build Coastguard Worker std::shared_ptr<PredictorBase> PosePredictor::getCurrentPredictor() const {
233*ec779b8eSAndroid Build Coastguard Worker     // we don't use a map here, we look up directly
234*ec779b8eSAndroid Build Coastguard Worker     switch (mCurrentType) {
235*ec779b8eSAndroid Build Coastguard Worker     default:
236*ec779b8eSAndroid Build Coastguard Worker     case android::media::PosePredictorType::LAST:
237*ec779b8eSAndroid Build Coastguard Worker         return mPredictors[0];
238*ec779b8eSAndroid Build Coastguard Worker     case android::media::PosePredictorType::TWIST:
239*ec779b8eSAndroid Build Coastguard Worker         return mPredictors[1];
240*ec779b8eSAndroid Build Coastguard Worker     case android::media::PosePredictorType::AUTO: // shouldn't occur here.
241*ec779b8eSAndroid Build Coastguard Worker     case android::media::PosePredictorType::LEAST_SQUARES:
242*ec779b8eSAndroid Build Coastguard Worker         return mPredictors[2];
243*ec779b8eSAndroid Build Coastguard Worker     }
244*ec779b8eSAndroid Build Coastguard Worker }
245*ec779b8eSAndroid Build Coastguard Worker 
246*ec779b8eSAndroid Build Coastguard Worker } // namespace android::media
247