1*ec779b8eSAndroid Build Coastguard Worker /*
2*ec779b8eSAndroid Build Coastguard Worker * Copyright (C) 2023 The Android Open Source Project
3*ec779b8eSAndroid Build Coastguard Worker *
4*ec779b8eSAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*ec779b8eSAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*ec779b8eSAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*ec779b8eSAndroid Build Coastguard Worker *
8*ec779b8eSAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*ec779b8eSAndroid Build Coastguard Worker *
10*ec779b8eSAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*ec779b8eSAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*ec779b8eSAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*ec779b8eSAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*ec779b8eSAndroid Build Coastguard Worker * limitations under the License.
15*ec779b8eSAndroid Build Coastguard Worker */
16*ec779b8eSAndroid Build Coastguard Worker
17*ec779b8eSAndroid Build Coastguard Worker #include "PosePredictor.h"
18*ec779b8eSAndroid Build Coastguard Worker
19*ec779b8eSAndroid Build Coastguard Worker namespace android::media {
20*ec779b8eSAndroid Build Coastguard Worker
namespace {

#if defined(ENABLE_VERIFICATION)
// Verification build: every predictor is kept up to date and its error is
// measured at each of these look-ahead horizons (milliseconds).
constexpr std::array<int, 3> kLookAheadMs{50, 100, 200};
constexpr bool kEnableVerification = true;
#else
// Regular build: only the selected predictor is fed, and no look-ahead
// horizons are evaluated.
constexpr std::array<int, 0> kLookAheadMs{};
constexpr bool kEnableVerification = false;
#endif

}  // namespace
31*ec779b8eSAndroid Build Coastguard Worker
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)32*ec779b8eSAndroid Build Coastguard Worker void LeastSquaresPredictor::add(int64_t atNs, const Pose3f& pose, const Twist3f& twist)
33*ec779b8eSAndroid Build Coastguard Worker {
34*ec779b8eSAndroid Build Coastguard Worker (void)twist;
35*ec779b8eSAndroid Build Coastguard Worker mLastAtNs = atNs;
36*ec779b8eSAndroid Build Coastguard Worker mLastPose = pose;
37*ec779b8eSAndroid Build Coastguard Worker const auto q = pose.rotation();
38*ec779b8eSAndroid Build Coastguard Worker const double datNs = static_cast<double>(atNs);
39*ec779b8eSAndroid Build Coastguard Worker mRw.add({datNs, q.w()});
40*ec779b8eSAndroid Build Coastguard Worker mRx.add({datNs, q.x()});
41*ec779b8eSAndroid Build Coastguard Worker mRy.add({datNs, q.y()});
42*ec779b8eSAndroid Build Coastguard Worker mRz.add({datNs, q.z()});
43*ec779b8eSAndroid Build Coastguard Worker }
44*ec779b8eSAndroid Build Coastguard Worker
predict(int64_t atNs) const45*ec779b8eSAndroid Build Coastguard Worker Pose3f LeastSquaresPredictor::predict(int64_t atNs) const
46*ec779b8eSAndroid Build Coastguard Worker {
47*ec779b8eSAndroid Build Coastguard Worker if (mRw.getN() < kMinimumSamplesForPrediction) return mLastPose;
48*ec779b8eSAndroid Build Coastguard Worker
49*ec779b8eSAndroid Build Coastguard Worker /*
50*ec779b8eSAndroid Build Coastguard Worker * Using parametric form, we have q(t) = { w(t), x(t), y(t), z(t) }.
51*ec779b8eSAndroid Build Coastguard Worker * We compute the least squares prediction of w, x, y, z.
52*ec779b8eSAndroid Build Coastguard Worker */
53*ec779b8eSAndroid Build Coastguard Worker const double dLookahead = static_cast<double>(atNs);
54*ec779b8eSAndroid Build Coastguard Worker Eigen::Quaternionf lsq(
55*ec779b8eSAndroid Build Coastguard Worker mRw.getYFromX(dLookahead),
56*ec779b8eSAndroid Build Coastguard Worker mRx.getYFromX(dLookahead),
57*ec779b8eSAndroid Build Coastguard Worker mRy.getYFromX(dLookahead),
58*ec779b8eSAndroid Build Coastguard Worker mRz.getYFromX(dLookahead));
59*ec779b8eSAndroid Build Coastguard Worker
60*ec779b8eSAndroid Build Coastguard Worker /*
61*ec779b8eSAndroid Build Coastguard Worker * We cheat here, since the result lsq is the least squares prediction
62*ec779b8eSAndroid Build Coastguard Worker * in H (arbitrary quaternion), not the least squares prediction in
63*ec779b8eSAndroid Build Coastguard Worker * SO(3) (unit quaternion).
64*ec779b8eSAndroid Build Coastguard Worker *
65*ec779b8eSAndroid Build Coastguard Worker * In other words, the result for lsq is most likely not a unit quaternion.
66*ec779b8eSAndroid Build Coastguard Worker * To solve this, we normalize, thereby selecting the closest unit quaternion
67*ec779b8eSAndroid Build Coastguard Worker * in SO(3) to the prediction in H.
68*ec779b8eSAndroid Build Coastguard Worker */
69*ec779b8eSAndroid Build Coastguard Worker lsq.normalize();
70*ec779b8eSAndroid Build Coastguard Worker return Pose3f(lsq);
71*ec779b8eSAndroid Build Coastguard Worker }
72*ec779b8eSAndroid Build Coastguard Worker
reset()73*ec779b8eSAndroid Build Coastguard Worker void LeastSquaresPredictor::reset() {
74*ec779b8eSAndroid Build Coastguard Worker mLastAtNs = {};
75*ec779b8eSAndroid Build Coastguard Worker mLastPose = {};
76*ec779b8eSAndroid Build Coastguard Worker mRw.reset();
77*ec779b8eSAndroid Build Coastguard Worker mRx.reset();
78*ec779b8eSAndroid Build Coastguard Worker mRy.reset();
79*ec779b8eSAndroid Build Coastguard Worker mRz.reset();
80*ec779b8eSAndroid Build Coastguard Worker }
81*ec779b8eSAndroid Build Coastguard Worker
toString(size_t index) const82*ec779b8eSAndroid Build Coastguard Worker std::string LeastSquaresPredictor::toString(size_t index) const {
83*ec779b8eSAndroid Build Coastguard Worker std::string s(index, ' ');
84*ec779b8eSAndroid Build Coastguard Worker s.append("LeastSquaresPredictor using alpha: ")
85*ec779b8eSAndroid Build Coastguard Worker .append(std::to_string(mAlpha))
86*ec779b8eSAndroid Build Coastguard Worker .append(" last pose: ")
87*ec779b8eSAndroid Build Coastguard Worker .append(mLastPose.toString())
88*ec779b8eSAndroid Build Coastguard Worker .append("\n");
89*ec779b8eSAndroid Build Coastguard Worker return s;
90*ec779b8eSAndroid Build Coastguard Worker }
91*ec779b8eSAndroid Build Coastguard Worker
// Formatting
// Returns the positions at which a flattened (lookahead x predictor) error
// vector should be split for display: one boundary after each group of
// `predictors` entries, except after the last group. Empty for 0 or 1 lookaheads.
static inline std::vector<size_t> createDelimiterIdx(size_t predictors, size_t lookaheads) {
    std::vector<size_t> boundaries;
    if (lookaheads > 1) {
        boundaries.reserve(lookaheads - 1);
        for (size_t next = predictors; boundaries.size() + 1 < lookaheads; next += predictors) {
            boundaries.push_back(next);
        }
    }
    return boundaries;
}
102*ec779b8eSAndroid Build Coastguard Worker
PosePredictor()103*ec779b8eSAndroid Build Coastguard Worker PosePredictor::PosePredictor()
104*ec779b8eSAndroid Build Coastguard Worker : mPredictors{
105*ec779b8eSAndroid Build Coastguard Worker // First predictors must match switch in getCurrentPredictor()
106*ec779b8eSAndroid Build Coastguard Worker std::make_shared<LastPredictor>(),
107*ec779b8eSAndroid Build Coastguard Worker std::make_shared<TwistPredictor>(),
108*ec779b8eSAndroid Build Coastguard Worker std::make_shared<LeastSquaresPredictor>(),
109*ec779b8eSAndroid Build Coastguard Worker // After this, can place additional predictors here for comparison such as
110*ec779b8eSAndroid Build Coastguard Worker // std::make_shared<LeastSquaresPredictor>(0.25),
111*ec779b8eSAndroid Build Coastguard Worker }
112*ec779b8eSAndroid Build Coastguard Worker , mLookaheadMs(kLookAheadMs.begin(), kLookAheadMs.end())
113*ec779b8eSAndroid Build Coastguard Worker , mVerifiers(std::size(mLookaheadMs) * std::size(mPredictors))
114*ec779b8eSAndroid Build Coastguard Worker , mDelimiterIdx(createDelimiterIdx(std::size(mPredictors), std::size(mLookaheadMs)))
115*ec779b8eSAndroid Build Coastguard Worker , mPredictionRecorder(
116*ec779b8eSAndroid Build Coastguard Worker std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
117*ec779b8eSAndroid Build Coastguard Worker mDelimiterIdx)
118*ec779b8eSAndroid Build Coastguard Worker , mPredictionDurableRecorder(
119*ec779b8eSAndroid Build Coastguard Worker std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
120*ec779b8eSAndroid Build Coastguard Worker mDelimiterIdx)
121*ec779b8eSAndroid Build Coastguard Worker {
122*ec779b8eSAndroid Build Coastguard Worker }
123*ec779b8eSAndroid Build Coastguard Worker
// Feeds a new pose sample and returns the pose predicted predictionDurationNs
// after timestampNs, using the currently selected predictor.
//
// timestampNs          sample time in nanoseconds.
// pose / twist         current head pose and its twist.
// predictionDurationNs how far ahead (ns) to predict; truncated toward zero.
//
// In verification builds every predictor is fed and scored against each
// configured look-ahead; otherwise only the selected predictor is fed.
Pose3f PosePredictor::predict(
        int64_t timestampNs, const Pose3f& pose, const Twist3f& twist, float predictionDurationNs)
{
    // A long gap since the previous sample makes all accumulated history stale.
    // NOTE(review): a timestamp that moves backwards (negative gap) does not
    // trigger a reset here; confirm that is intended.
    if (timestampNs - mLastTimestampNs > kMaximumSampleIntervalBeforeResetNs) {
        for (const auto& predictor : mPredictors) {
            predictor->reset();
        }
        ++mResets;
    }
    mLastTimestampNs = timestampNs;

    const auto selectedPredictor = getCurrentPredictor();
    if constexpr (kEnableVerification) {
        // Verifier errors are converted from radians to degrees before recording.
        constexpr float kRadiansToDegrees = static_cast<float>(180.0 / M_PI);

        // Keep every predictor current so each can be scored.
        for (const auto& predictor : mPredictors) {
            predictor->add(timestampNs, pose, twist);
        }

        // Score earlier predictions against this pose and queue new ones.
        // error[] is laid out lookahead-major: index = lookahead * predictors + predictor.
        std::vector<float> error(std::size(mVerifiers));
        for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
            const int64_t atNs =
                    timestampNs + mLookaheadMs[i] * PosePredictorVerifier::kMillisToNanos;

            for (size_t j = 0; j < mPredictors.size(); ++j) {
                const size_t idx = i * std::size(mPredictors) + j;
                mVerifiers[idx].verifyActualPose(timestampNs, pose);
                mVerifiers[idx].addPredictedPose(atNs, mPredictors[j]->predict(atNs));
                error[idx] = kRadiansToDegrees * mVerifiers[idx].lastError();
            }
        }

        // Record errors at both the per-second and per-minute granularity.
        mPredictionRecorder.record(error);
        mPredictionDurableRecorder.record(error);
    } else /* constexpr */ {
        selectedPredictor->add(timestampNs, pose, twist);
    }

    // Deliver prediction.
    const int64_t predictionTimeNs = timestampNs + static_cast<int64_t>(predictionDurationNs);
    return selectedPredictor->predict(predictionTimeNs);
}
167*ec779b8eSAndroid Build Coastguard Worker
setPosePredictorType(PosePredictorType type)168*ec779b8eSAndroid Build Coastguard Worker void PosePredictor::setPosePredictorType(PosePredictorType type) {
169*ec779b8eSAndroid Build Coastguard Worker if (!isValidPosePredictorType(type)) return;
170*ec779b8eSAndroid Build Coastguard Worker if (type == mSetType) return;
171*ec779b8eSAndroid Build Coastguard Worker mSetType = type;
172*ec779b8eSAndroid Build Coastguard Worker if (type == android::media::PosePredictorType::AUTO) {
173*ec779b8eSAndroid Build Coastguard Worker type = android::media::PosePredictorType::LEAST_SQUARES;
174*ec779b8eSAndroid Build Coastguard Worker }
175*ec779b8eSAndroid Build Coastguard Worker if (type != mCurrentType) {
176*ec779b8eSAndroid Build Coastguard Worker mCurrentType = type;
177*ec779b8eSAndroid Build Coastguard Worker if constexpr (!kEnableVerification) {
178*ec779b8eSAndroid Build Coastguard Worker // Verification keeps all predictors up-to-date.
179*ec779b8eSAndroid Build Coastguard Worker // If we don't enable verification, we must reset the current predictor.
180*ec779b8eSAndroid Build Coastguard Worker getCurrentPredictor()->reset();
181*ec779b8eSAndroid Build Coastguard Worker }
182*ec779b8eSAndroid Build Coastguard Worker }
183*ec779b8eSAndroid Build Coastguard Worker }
184*ec779b8eSAndroid Build Coastguard Worker
toString(size_t index) const185*ec779b8eSAndroid Build Coastguard Worker std::string PosePredictor::toString(size_t index) const {
186*ec779b8eSAndroid Build Coastguard Worker std::string prefixSpace(index, ' ');
187*ec779b8eSAndroid Build Coastguard Worker std::string ss(prefixSpace);
188*ec779b8eSAndroid Build Coastguard Worker ss.append("PosePredictor:\n")
189*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
190*ec779b8eSAndroid Build Coastguard Worker .append(" Current Prediction Type: ")
191*ec779b8eSAndroid Build Coastguard Worker .append(android::media::toString(mCurrentType))
192*ec779b8eSAndroid Build Coastguard Worker .append("\n")
193*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
194*ec779b8eSAndroid Build Coastguard Worker .append(" Resets: ")
195*ec779b8eSAndroid Build Coastguard Worker .append(std::to_string(mResets))
196*ec779b8eSAndroid Build Coastguard Worker .append("\n")
197*ec779b8eSAndroid Build Coastguard Worker .append(getCurrentPredictor()->toString(index + 1));
198*ec779b8eSAndroid Build Coastguard Worker if constexpr (kEnableVerification) {
199*ec779b8eSAndroid Build Coastguard Worker // dump verification
200*ec779b8eSAndroid Build Coastguard Worker ss.append(prefixSpace)
201*ec779b8eSAndroid Build Coastguard Worker .append(" Prediction abs error (L1) degrees [ type (");
202*ec779b8eSAndroid Build Coastguard Worker for (size_t i = 0; i < mPredictors.size(); ++i) {
203*ec779b8eSAndroid Build Coastguard Worker if (i > 0) ss.append(" , ");
204*ec779b8eSAndroid Build Coastguard Worker ss.append(mPredictors[i]->name());
205*ec779b8eSAndroid Build Coastguard Worker }
206*ec779b8eSAndroid Build Coastguard Worker ss.append(" ) x ( ");
207*ec779b8eSAndroid Build Coastguard Worker for (size_t i = 0; i < mLookaheadMs.size(); ++i) {
208*ec779b8eSAndroid Build Coastguard Worker if (i > 0) ss.append(" : ");
209*ec779b8eSAndroid Build Coastguard Worker ss.append(std::to_string(mLookaheadMs[i]));
210*ec779b8eSAndroid Build Coastguard Worker }
211*ec779b8eSAndroid Build Coastguard Worker std::vector<float> cumulativeAverageErrors(std::size(mVerifiers));
212*ec779b8eSAndroid Build Coastguard Worker for (size_t i = 0; i < cumulativeAverageErrors.size(); ++i) {
213*ec779b8eSAndroid Build Coastguard Worker cumulativeAverageErrors[i] = mVerifiers[i].cumulativeAverageError();
214*ec779b8eSAndroid Build Coastguard Worker }
215*ec779b8eSAndroid Build Coastguard Worker ss.append(" ) ms ]\n")
216*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
217*ec779b8eSAndroid Build Coastguard Worker .append(" Cumulative Average Error:\n")
218*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
219*ec779b8eSAndroid Build Coastguard Worker .append(" ")
220*ec779b8eSAndroid Build Coastguard Worker .append(VectorRecorder::toString(cumulativeAverageErrors, mDelimiterIdx, "%.3g"))
221*ec779b8eSAndroid Build Coastguard Worker .append("\n")
222*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
223*ec779b8eSAndroid Build Coastguard Worker .append(" PerMinuteHistory:\n")
224*ec779b8eSAndroid Build Coastguard Worker .append(mPredictionDurableRecorder.toString(index + 3))
225*ec779b8eSAndroid Build Coastguard Worker .append(prefixSpace)
226*ec779b8eSAndroid Build Coastguard Worker .append(" PerSecondHistory:\n")
227*ec779b8eSAndroid Build Coastguard Worker .append(mPredictionRecorder.toString(index + 3));
228*ec779b8eSAndroid Build Coastguard Worker }
229*ec779b8eSAndroid Build Coastguard Worker return ss;
230*ec779b8eSAndroid Build Coastguard Worker }
231*ec779b8eSAndroid Build Coastguard Worker
getCurrentPredictor() const232*ec779b8eSAndroid Build Coastguard Worker std::shared_ptr<PredictorBase> PosePredictor::getCurrentPredictor() const {
233*ec779b8eSAndroid Build Coastguard Worker // we don't use a map here, we look up directly
234*ec779b8eSAndroid Build Coastguard Worker switch (mCurrentType) {
235*ec779b8eSAndroid Build Coastguard Worker default:
236*ec779b8eSAndroid Build Coastguard Worker case android::media::PosePredictorType::LAST:
237*ec779b8eSAndroid Build Coastguard Worker return mPredictors[0];
238*ec779b8eSAndroid Build Coastguard Worker case android::media::PosePredictorType::TWIST:
239*ec779b8eSAndroid Build Coastguard Worker return mPredictors[1];
240*ec779b8eSAndroid Build Coastguard Worker case android::media::PosePredictorType::AUTO: // shouldn't occur here.
241*ec779b8eSAndroid Build Coastguard Worker case android::media::PosePredictorType::LEAST_SQUARES:
242*ec779b8eSAndroid Build Coastguard Worker return mPredictors[2];
243*ec779b8eSAndroid Build Coastguard Worker }
244*ec779b8eSAndroid Build Coastguard Worker }
245*ec779b8eSAndroid Build Coastguard Worker
246*ec779b8eSAndroid Build Coastguard Worker } // namespace android::media
247