xref: /aosp_15_r20/frameworks/native/libs/input/MotionPredictorMetricsManager.cpp (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1*38e8c45fSAndroid Build Coastguard Worker /*
2*38e8c45fSAndroid Build Coastguard Worker  * Copyright 2023 The Android Open Source Project
3*38e8c45fSAndroid Build Coastguard Worker  *
4*38e8c45fSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*38e8c45fSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*38e8c45fSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*38e8c45fSAndroid Build Coastguard Worker  *
8*38e8c45fSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*38e8c45fSAndroid Build Coastguard Worker  *
10*38e8c45fSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*38e8c45fSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*38e8c45fSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*38e8c45fSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*38e8c45fSAndroid Build Coastguard Worker  * limitations under the License.
15*38e8c45fSAndroid Build Coastguard Worker  */
16*38e8c45fSAndroid Build Coastguard Worker 
17*38e8c45fSAndroid Build Coastguard Worker #define LOG_TAG "MotionPredictorMetricsManager"
18*38e8c45fSAndroid Build Coastguard Worker 
19*38e8c45fSAndroid Build Coastguard Worker #include <input/MotionPredictorMetricsManager.h>
20*38e8c45fSAndroid Build Coastguard Worker 
21*38e8c45fSAndroid Build Coastguard Worker #include <algorithm>
22*38e8c45fSAndroid Build Coastguard Worker 
23*38e8c45fSAndroid Build Coastguard Worker #include <android-base/logging.h>
24*38e8c45fSAndroid Build Coastguard Worker #ifdef __ANDROID__
25*38e8c45fSAndroid Build Coastguard Worker #include <statslog_libinput.h>
26*38e8c45fSAndroid Build Coastguard Worker #endif // __ANDROID__
27*38e8c45fSAndroid Build Coastguard Worker 
28*38e8c45fSAndroid Build Coastguard Worker #include "Eigen/Core"
29*38e8c45fSAndroid Build Coastguard Worker #include "Eigen/Geometry"
30*38e8c45fSAndroid Build Coastguard Worker 
31*38e8c45fSAndroid Build Coastguard Worker namespace android {
namespace {

// Time unit conversion factors.
constexpr int NANOS_PER_SECOND = 1'000'000'000; // nanoseconds in one second
constexpr int NANOS_PER_MILLIS = 1'000'000;     // nanoseconds in one millisecond

// Ground-truth speed, in pixels per second, above which errors are additionally aggregated into
// the "high-velocity" metrics. The value comes from manual experimentation: it separates fast
// (somewhat sloppy) handwriting from careful, medium-to-slow handwriting.
constexpr float HIGH_VELOCITY_THRESHOLD = 1100.0;

// Offset added to the ground-truth path length before it is used as the divisor of the
// scale-invariant errors, so that the divisor can never be zero.
constexpr float PATH_LENGTH_EPSILON = 0.001;

} // namespace
47*38e8c45fSAndroid Build Coastguard Worker 
defaultReportAtomFunction(const MotionPredictorMetricsManager::AtomFields & atomFields)48*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::defaultReportAtomFunction(
49*38e8c45fSAndroid Build Coastguard Worker         const MotionPredictorMetricsManager::AtomFields& atomFields) {
50*38e8c45fSAndroid Build Coastguard Worker #ifdef __ANDROID__
51*38e8c45fSAndroid Build Coastguard Worker     android::libinput::stats_write(android::libinput::STYLUS_PREDICTION_METRICS_REPORTED,
52*38e8c45fSAndroid Build Coastguard Worker                                    /*stylus_vendor_id=*/0,
53*38e8c45fSAndroid Build Coastguard Worker                                    /*stylus_product_id=*/0,
54*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.deltaTimeBucketMilliseconds,
55*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.alongTrajectoryErrorMeanMillipixels,
56*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.alongTrajectoryErrorStdMillipixels,
57*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.offTrajectoryRmseMillipixels,
58*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.pressureRmseMilliunits,
59*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.highVelocityAlongTrajectoryRmse,
60*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.highVelocityOffTrajectoryRmse,
61*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.scaleInvariantAlongTrajectoryRmse,
62*38e8c45fSAndroid Build Coastguard Worker                                    atomFields.scaleInvariantOffTrajectoryRmse);
63*38e8c45fSAndroid Build Coastguard Worker #endif // __ANDROID__
64*38e8c45fSAndroid Build Coastguard Worker }
65*38e8c45fSAndroid Build Coastguard Worker 
MotionPredictorMetricsManager(nsecs_t predictionInterval,size_t maxNumPredictions,ReportAtomFunction reportAtomFunction)66*38e8c45fSAndroid Build Coastguard Worker MotionPredictorMetricsManager::MotionPredictorMetricsManager(
67*38e8c45fSAndroid Build Coastguard Worker         nsecs_t predictionInterval,
68*38e8c45fSAndroid Build Coastguard Worker         size_t maxNumPredictions,
69*38e8c45fSAndroid Build Coastguard Worker         ReportAtomFunction reportAtomFunction)
70*38e8c45fSAndroid Build Coastguard Worker       : mPredictionInterval(predictionInterval),
71*38e8c45fSAndroid Build Coastguard Worker         mMaxNumPredictions(maxNumPredictions),
72*38e8c45fSAndroid Build Coastguard Worker         mRecentGroundTruthPoints(maxNumPredictions + 1),
73*38e8c45fSAndroid Build Coastguard Worker         mAggregatedMetrics(maxNumPredictions),
74*38e8c45fSAndroid Build Coastguard Worker         mAtomFields(maxNumPredictions),
75*38e8c45fSAndroid Build Coastguard Worker         mReportAtomFunction(reportAtomFunction ? reportAtomFunction : defaultReportAtomFunction) {}
76*38e8c45fSAndroid Build Coastguard Worker 
onRecord(const MotionEvent & inputEvent)77*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) {
78*38e8c45fSAndroid Build Coastguard Worker     // Convert MotionEvent to GroundTruthPoint.
79*38e8c45fSAndroid Build Coastguard Worker     const PointerCoords* coords = inputEvent.getRawPointerCoords(/*pointerIndex=*/0);
80*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(coords == nullptr);
81*38e8c45fSAndroid Build Coastguard Worker     const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f{coords->getY(),
82*38e8c45fSAndroid Build Coastguard Worker                                                                          coords->getX()},
83*38e8c45fSAndroid Build Coastguard Worker                                              .pressure =
84*38e8c45fSAndroid Build Coastguard Worker                                                      inputEvent.getPressure(/*pointerIndex=*/0)},
85*38e8c45fSAndroid Build Coastguard Worker                                             .timestamp = inputEvent.getEventTime()};
86*38e8c45fSAndroid Build Coastguard Worker 
87*38e8c45fSAndroid Build Coastguard Worker     // Handle event based on action type.
88*38e8c45fSAndroid Build Coastguard Worker     switch (inputEvent.getActionMasked()) {
89*38e8c45fSAndroid Build Coastguard Worker         case AMOTION_EVENT_ACTION_DOWN: {
90*38e8c45fSAndroid Build Coastguard Worker             clearStrokeData();
91*38e8c45fSAndroid Build Coastguard Worker             incorporateNewGroundTruth(groundTruthPoint);
92*38e8c45fSAndroid Build Coastguard Worker             break;
93*38e8c45fSAndroid Build Coastguard Worker         }
94*38e8c45fSAndroid Build Coastguard Worker         case AMOTION_EVENT_ACTION_MOVE: {
95*38e8c45fSAndroid Build Coastguard Worker             incorporateNewGroundTruth(groundTruthPoint);
96*38e8c45fSAndroid Build Coastguard Worker             break;
97*38e8c45fSAndroid Build Coastguard Worker         }
98*38e8c45fSAndroid Build Coastguard Worker         case AMOTION_EVENT_ACTION_UP:
99*38e8c45fSAndroid Build Coastguard Worker         case AMOTION_EVENT_ACTION_CANCEL: {
100*38e8c45fSAndroid Build Coastguard Worker             // Only expect meaningful predictions when given at least two input points.
101*38e8c45fSAndroid Build Coastguard Worker             if (mRecentGroundTruthPoints.size() >= 2) {
102*38e8c45fSAndroid Build Coastguard Worker                 computeAtomFields();
103*38e8c45fSAndroid Build Coastguard Worker                 reportMetrics();
104*38e8c45fSAndroid Build Coastguard Worker             }
105*38e8c45fSAndroid Build Coastguard Worker             break;
106*38e8c45fSAndroid Build Coastguard Worker         }
107*38e8c45fSAndroid Build Coastguard Worker     }
108*38e8c45fSAndroid Build Coastguard Worker }
109*38e8c45fSAndroid Build Coastguard Worker 
110*38e8c45fSAndroid Build Coastguard Worker // Adds new predictions to mRecentPredictions and maintains the invariant that elements are
111*38e8c45fSAndroid Build Coastguard Worker // sorted in ascending order of targetTimestamp.
onPredict(const MotionEvent & predictionEvent)112*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::onPredict(const MotionEvent& predictionEvent) {
113*38e8c45fSAndroid Build Coastguard Worker     const size_t numPredictions = predictionEvent.getHistorySize() + 1;
114*38e8c45fSAndroid Build Coastguard Worker     if (numPredictions > mMaxNumPredictions) {
115*38e8c45fSAndroid Build Coastguard Worker         LOG(WARNING) << "numPredictions (" << numPredictions << ") > mMaxNumPredictions ("
116*38e8c45fSAndroid Build Coastguard Worker                      << mMaxNumPredictions << "). Ignoring extra predictions in metrics.";
117*38e8c45fSAndroid Build Coastguard Worker     }
118*38e8c45fSAndroid Build Coastguard Worker     for (size_t i = 0; (i < numPredictions) && (i < mMaxNumPredictions); ++i) {
119*38e8c45fSAndroid Build Coastguard Worker         // Convert MotionEvent to PredictionPoint.
120*38e8c45fSAndroid Build Coastguard Worker         const PointerCoords* coords =
121*38e8c45fSAndroid Build Coastguard Worker                 predictionEvent.getHistoricalRawPointerCoords(/*pointerIndex=*/0, i);
122*38e8c45fSAndroid Build Coastguard Worker         LOG_ALWAYS_FATAL_IF(coords == nullptr);
123*38e8c45fSAndroid Build Coastguard Worker         const nsecs_t targetTimestamp = predictionEvent.getHistoricalEventTime(i);
124*38e8c45fSAndroid Build Coastguard Worker         mRecentPredictions.push_back(
125*38e8c45fSAndroid Build Coastguard Worker                 PredictionPoint{{.position = Eigen::Vector2f{coords->getY(), coords->getX()},
126*38e8c45fSAndroid Build Coastguard Worker                                  .pressure =
127*38e8c45fSAndroid Build Coastguard Worker                                          predictionEvent.getHistoricalPressure(/*pointerIndex=*/0,
128*38e8c45fSAndroid Build Coastguard Worker                                                                                i)},
129*38e8c45fSAndroid Build Coastguard Worker                                 .originTimestamp = mRecentGroundTruthPoints.back().timestamp,
130*38e8c45fSAndroid Build Coastguard Worker                                 .targetTimestamp = targetTimestamp});
131*38e8c45fSAndroid Build Coastguard Worker     }
132*38e8c45fSAndroid Build Coastguard Worker 
133*38e8c45fSAndroid Build Coastguard Worker     std::sort(mRecentPredictions.begin(), mRecentPredictions.end());
134*38e8c45fSAndroid Build Coastguard Worker }
135*38e8c45fSAndroid Build Coastguard Worker 
clearStrokeData()136*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::clearStrokeData() {
137*38e8c45fSAndroid Build Coastguard Worker     mRecentGroundTruthPoints.clear();
138*38e8c45fSAndroid Build Coastguard Worker     mRecentPredictions.clear();
139*38e8c45fSAndroid Build Coastguard Worker     std::fill(mAggregatedMetrics.begin(), mAggregatedMetrics.end(), AggregatedStrokeMetrics{});
140*38e8c45fSAndroid Build Coastguard Worker     std::fill(mAtomFields.begin(), mAtomFields.end(), AtomFields{});
141*38e8c45fSAndroid Build Coastguard Worker }
142*38e8c45fSAndroid Build Coastguard Worker 
incorporateNewGroundTruth(const GroundTruthPoint & groundTruthPoint)143*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::incorporateNewGroundTruth(
144*38e8c45fSAndroid Build Coastguard Worker         const GroundTruthPoint& groundTruthPoint) {
145*38e8c45fSAndroid Build Coastguard Worker     // Note: this removes the oldest point if `mRecentGroundTruthPoints` is already at capacity.
146*38e8c45fSAndroid Build Coastguard Worker     mRecentGroundTruthPoints.pushBack(groundTruthPoint);
147*38e8c45fSAndroid Build Coastguard Worker 
148*38e8c45fSAndroid Build Coastguard Worker     // Remove outdated predictions – those that can never be matched with the current or any future
149*38e8c45fSAndroid Build Coastguard Worker     // ground truth points. We use fuzzy association for the timestamps here, because ground truth
150*38e8c45fSAndroid Build Coastguard Worker     // and prediction timestamps may not be perfectly synchronized.
151*38e8c45fSAndroid Build Coastguard Worker     const nsecs_t fuzzy_association_time_delta = mPredictionInterval / 4;
152*38e8c45fSAndroid Build Coastguard Worker     const auto firstCurrentIt =
153*38e8c45fSAndroid Build Coastguard Worker             std::find_if(mRecentPredictions.begin(), mRecentPredictions.end(),
154*38e8c45fSAndroid Build Coastguard Worker                          [&groundTruthPoint,
155*38e8c45fSAndroid Build Coastguard Worker                           fuzzy_association_time_delta](const PredictionPoint& prediction) {
156*38e8c45fSAndroid Build Coastguard Worker                              return prediction.targetTimestamp >
157*38e8c45fSAndroid Build Coastguard Worker                                      groundTruthPoint.timestamp - fuzzy_association_time_delta;
158*38e8c45fSAndroid Build Coastguard Worker                          });
159*38e8c45fSAndroid Build Coastguard Worker     mRecentPredictions.erase(mRecentPredictions.begin(), firstCurrentIt);
160*38e8c45fSAndroid Build Coastguard Worker 
161*38e8c45fSAndroid Build Coastguard Worker     // Fuzzily match the new ground truth's timestamp to recent predictions' targetTimestamp and
162*38e8c45fSAndroid Build Coastguard Worker     // update the corresponding metrics.
163*38e8c45fSAndroid Build Coastguard Worker     for (const PredictionPoint& prediction : mRecentPredictions) {
164*38e8c45fSAndroid Build Coastguard Worker         if ((prediction.targetTimestamp >
165*38e8c45fSAndroid Build Coastguard Worker              groundTruthPoint.timestamp - fuzzy_association_time_delta) &&
166*38e8c45fSAndroid Build Coastguard Worker             (prediction.targetTimestamp <
167*38e8c45fSAndroid Build Coastguard Worker              groundTruthPoint.timestamp + fuzzy_association_time_delta)) {
168*38e8c45fSAndroid Build Coastguard Worker             updateAggregatedMetrics(prediction);
169*38e8c45fSAndroid Build Coastguard Worker         }
170*38e8c45fSAndroid Build Coastguard Worker     }
171*38e8c45fSAndroid Build Coastguard Worker }
172*38e8c45fSAndroid Build Coastguard Worker 
updateAggregatedMetrics(const PredictionPoint & predictionPoint)173*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::updateAggregatedMetrics(
174*38e8c45fSAndroid Build Coastguard Worker         const PredictionPoint& predictionPoint) {
175*38e8c45fSAndroid Build Coastguard Worker     if (mRecentGroundTruthPoints.size() < 2) {
176*38e8c45fSAndroid Build Coastguard Worker         return;
177*38e8c45fSAndroid Build Coastguard Worker     }
178*38e8c45fSAndroid Build Coastguard Worker 
179*38e8c45fSAndroid Build Coastguard Worker     const GroundTruthPoint& latestGroundTruthPoint = mRecentGroundTruthPoints.back();
180*38e8c45fSAndroid Build Coastguard Worker     const GroundTruthPoint& previousGroundTruthPoint =
181*38e8c45fSAndroid Build Coastguard Worker             mRecentGroundTruthPoints[mRecentGroundTruthPoints.size() - 2];
182*38e8c45fSAndroid Build Coastguard Worker     // Calculate prediction error vector.
183*38e8c45fSAndroid Build Coastguard Worker     const Eigen::Vector2f groundTruthTrajectory =
184*38e8c45fSAndroid Build Coastguard Worker             latestGroundTruthPoint.position - previousGroundTruthPoint.position;
185*38e8c45fSAndroid Build Coastguard Worker     const Eigen::Vector2f predictionTrajectory =
186*38e8c45fSAndroid Build Coastguard Worker             predictionPoint.position - previousGroundTruthPoint.position;
187*38e8c45fSAndroid Build Coastguard Worker     const Eigen::Vector2f predictionError = predictionTrajectory - groundTruthTrajectory;
188*38e8c45fSAndroid Build Coastguard Worker 
189*38e8c45fSAndroid Build Coastguard Worker     // By default, prediction error counts fully as both off-trajectory and along-trajectory error.
190*38e8c45fSAndroid Build Coastguard Worker     // This serves as the fallback when the two most recent ground truth points are equal.
191*38e8c45fSAndroid Build Coastguard Worker     const float predictionErrorNorm = predictionError.norm();
192*38e8c45fSAndroid Build Coastguard Worker     float alongTrajectoryError = predictionErrorNorm;
193*38e8c45fSAndroid Build Coastguard Worker     float offTrajectoryError = predictionErrorNorm;
194*38e8c45fSAndroid Build Coastguard Worker     if (groundTruthTrajectory.squaredNorm() > 0) {
195*38e8c45fSAndroid Build Coastguard Worker         // Rotate the prediction error vector by the angle of the ground truth trajectory vector.
196*38e8c45fSAndroid Build Coastguard Worker         // This yields a vector whose first component is the along-trajectory error and whose
197*38e8c45fSAndroid Build Coastguard Worker         // second component is the off-trajectory error.
198*38e8c45fSAndroid Build Coastguard Worker         const float theta = std::atan2(groundTruthTrajectory[1], groundTruthTrajectory[0]);
199*38e8c45fSAndroid Build Coastguard Worker         const Eigen::Vector2f rotatedPredictionError = Eigen::Rotation2Df(-theta) * predictionError;
200*38e8c45fSAndroid Build Coastguard Worker         alongTrajectoryError = rotatedPredictionError[0];
201*38e8c45fSAndroid Build Coastguard Worker         offTrajectoryError = rotatedPredictionError[1];
202*38e8c45fSAndroid Build Coastguard Worker     }
203*38e8c45fSAndroid Build Coastguard Worker 
204*38e8c45fSAndroid Build Coastguard Worker     // Compute the multiple of mPredictionInterval nearest to the amount of time into the
205*38e8c45fSAndroid Build Coastguard Worker     // future being predicted. This serves as the time bucket index into mAggregatedMetrics.
206*38e8c45fSAndroid Build Coastguard Worker     const float timestampDeltaFloat =
207*38e8c45fSAndroid Build Coastguard Worker             static_cast<float>(predictionPoint.targetTimestamp - predictionPoint.originTimestamp);
208*38e8c45fSAndroid Build Coastguard Worker     const size_t tIndex =
209*38e8c45fSAndroid Build Coastguard Worker             static_cast<size_t>(std::round(timestampDeltaFloat / mPredictionInterval - 1));
210*38e8c45fSAndroid Build Coastguard Worker 
211*38e8c45fSAndroid Build Coastguard Worker     // Aggregate values into "general errors".
212*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].alongTrajectoryErrorSum += alongTrajectoryError;
213*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].alongTrajectorySumSquaredErrors +=
214*38e8c45fSAndroid Build Coastguard Worker             alongTrajectoryError * alongTrajectoryError;
215*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].offTrajectorySumSquaredErrors +=
216*38e8c45fSAndroid Build Coastguard Worker             offTrajectoryError * offTrajectoryError;
217*38e8c45fSAndroid Build Coastguard Worker     const float pressureError = predictionPoint.pressure - latestGroundTruthPoint.pressure;
218*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].pressureSumSquaredErrors += pressureError * pressureError;
219*38e8c45fSAndroid Build Coastguard Worker     ++mAggregatedMetrics[tIndex].generalErrorsCount;
220*38e8c45fSAndroid Build Coastguard Worker 
221*38e8c45fSAndroid Build Coastguard Worker     // Aggregate values into high-velocity metrics, if we are in one of the last two time buckets
222*38e8c45fSAndroid Build Coastguard Worker     // and the velocity is above the threshold. Velocity here is measured in pixels per second.
223*38e8c45fSAndroid Build Coastguard Worker     const float velocity = groundTruthTrajectory.norm() /
224*38e8c45fSAndroid Build Coastguard Worker             (static_cast<float>(latestGroundTruthPoint.timestamp -
225*38e8c45fSAndroid Build Coastguard Worker                                 previousGroundTruthPoint.timestamp) /
226*38e8c45fSAndroid Build Coastguard Worker              NANOS_PER_SECOND);
227*38e8c45fSAndroid Build Coastguard Worker     if ((tIndex + 2 >= mMaxNumPredictions) && (velocity > HIGH_VELOCITY_THRESHOLD)) {
228*38e8c45fSAndroid Build Coastguard Worker         mAggregatedMetrics[tIndex].highVelocityAlongTrajectorySse +=
229*38e8c45fSAndroid Build Coastguard Worker                 alongTrajectoryError * alongTrajectoryError;
230*38e8c45fSAndroid Build Coastguard Worker         mAggregatedMetrics[tIndex].highVelocityOffTrajectorySse +=
231*38e8c45fSAndroid Build Coastguard Worker                 offTrajectoryError * offTrajectoryError;
232*38e8c45fSAndroid Build Coastguard Worker         ++mAggregatedMetrics[tIndex].highVelocityErrorsCount;
233*38e8c45fSAndroid Build Coastguard Worker     }
234*38e8c45fSAndroid Build Coastguard Worker 
235*38e8c45fSAndroid Build Coastguard Worker     // Compute path length for scale-invariant errors.
236*38e8c45fSAndroid Build Coastguard Worker     float pathLength = 0;
237*38e8c45fSAndroid Build Coastguard Worker     for (size_t i = 1; i < mRecentGroundTruthPoints.size(); ++i) {
238*38e8c45fSAndroid Build Coastguard Worker         pathLength +=
239*38e8c45fSAndroid Build Coastguard Worker                 (mRecentGroundTruthPoints[i].position - mRecentGroundTruthPoints[i - 1].position)
240*38e8c45fSAndroid Build Coastguard Worker                         .norm();
241*38e8c45fSAndroid Build Coastguard Worker     }
242*38e8c45fSAndroid Build Coastguard Worker     // Avoid overweighting errors at the beginning of a stroke: compute the path length as if there
243*38e8c45fSAndroid Build Coastguard Worker     // were a full ground truth history by filling in missing segments with the average length.
244*38e8c45fSAndroid Build Coastguard Worker     // Note: the "- 1" is needed to translate from number of endpoints to number of segments.
245*38e8c45fSAndroid Build Coastguard Worker     pathLength *= static_cast<float>(mRecentGroundTruthPoints.capacity() - 1) /
246*38e8c45fSAndroid Build Coastguard Worker             (mRecentGroundTruthPoints.size() - 1);
247*38e8c45fSAndroid Build Coastguard Worker     pathLength += PATH_LENGTH_EPSILON; // Ensure path length is nonzero (>= PATH_LENGTH_EPSILON).
248*38e8c45fSAndroid Build Coastguard Worker 
249*38e8c45fSAndroid Build Coastguard Worker     // Compute and aggregate scale-invariant errors.
250*38e8c45fSAndroid Build Coastguard Worker     const float scaleInvariantAlongTrajectoryError = alongTrajectoryError / pathLength;
251*38e8c45fSAndroid Build Coastguard Worker     const float scaleInvariantOffTrajectoryError = offTrajectoryError / pathLength;
252*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].scaleInvariantAlongTrajectorySse +=
253*38e8c45fSAndroid Build Coastguard Worker             scaleInvariantAlongTrajectoryError * scaleInvariantAlongTrajectoryError;
254*38e8c45fSAndroid Build Coastguard Worker     mAggregatedMetrics[tIndex].scaleInvariantOffTrajectorySse +=
255*38e8c45fSAndroid Build Coastguard Worker             scaleInvariantOffTrajectoryError * scaleInvariantOffTrajectoryError;
256*38e8c45fSAndroid Build Coastguard Worker     ++mAggregatedMetrics[tIndex].scaleInvariantErrorsCount;
257*38e8c45fSAndroid Build Coastguard Worker }
258*38e8c45fSAndroid Build Coastguard Worker 
computeAtomFields()259*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::computeAtomFields() {
260*38e8c45fSAndroid Build Coastguard Worker     for (size_t i = 0; i < mAggregatedMetrics.size(); ++i) {
261*38e8c45fSAndroid Build Coastguard Worker         if (mAggregatedMetrics[i].generalErrorsCount == 0) {
262*38e8c45fSAndroid Build Coastguard Worker             // We have not received data corresponding to metrics for this time bucket.
263*38e8c45fSAndroid Build Coastguard Worker             continue;
264*38e8c45fSAndroid Build Coastguard Worker         }
265*38e8c45fSAndroid Build Coastguard Worker 
266*38e8c45fSAndroid Build Coastguard Worker         mAtomFields[i].deltaTimeBucketMilliseconds =
267*38e8c45fSAndroid Build Coastguard Worker                 static_cast<int>(mPredictionInterval / NANOS_PER_MILLIS * (i + 1));
268*38e8c45fSAndroid Build Coastguard Worker 
269*38e8c45fSAndroid Build Coastguard Worker         // Note: we need the "* 1000"s below because we report values in integral milli-units.
270*38e8c45fSAndroid Build Coastguard Worker 
271*38e8c45fSAndroid Build Coastguard Worker         { // General errors: reported for every time bucket.
272*38e8c45fSAndroid Build Coastguard Worker             const float alongTrajectoryErrorMean = mAggregatedMetrics[i].alongTrajectoryErrorSum /
273*38e8c45fSAndroid Build Coastguard Worker                     mAggregatedMetrics[i].generalErrorsCount;
274*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].alongTrajectoryErrorMeanMillipixels =
275*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(alongTrajectoryErrorMean * 1000);
276*38e8c45fSAndroid Build Coastguard Worker 
277*38e8c45fSAndroid Build Coastguard Worker             const float alongTrajectoryMse = mAggregatedMetrics[i].alongTrajectorySumSquaredErrors /
278*38e8c45fSAndroid Build Coastguard Worker                     mAggregatedMetrics[i].generalErrorsCount;
279*38e8c45fSAndroid Build Coastguard Worker             // Take the max with 0 to avoid negative values caused by numerical instability.
280*38e8c45fSAndroid Build Coastguard Worker             const float alongTrajectoryErrorVariance =
281*38e8c45fSAndroid Build Coastguard Worker                     std::max(0.0f,
282*38e8c45fSAndroid Build Coastguard Worker                              alongTrajectoryMse -
283*38e8c45fSAndroid Build Coastguard Worker                                      alongTrajectoryErrorMean * alongTrajectoryErrorMean);
284*38e8c45fSAndroid Build Coastguard Worker             const float alongTrajectoryErrorStd = std::sqrt(alongTrajectoryErrorVariance);
285*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].alongTrajectoryErrorStdMillipixels =
286*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(alongTrajectoryErrorStd * 1000);
287*38e8c45fSAndroid Build Coastguard Worker 
288*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].offTrajectorySumSquaredErrors < 0,
289*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].offTrajectorySumSquaredErrors = %f should "
290*38e8c45fSAndroid Build Coastguard Worker                                 "not be negative",
291*38e8c45fSAndroid Build Coastguard Worker                                 i, mAggregatedMetrics[i].offTrajectorySumSquaredErrors);
292*38e8c45fSAndroid Build Coastguard Worker             const float offTrajectoryRmse =
293*38e8c45fSAndroid Build Coastguard Worker                     std::sqrt(mAggregatedMetrics[i].offTrajectorySumSquaredErrors /
294*38e8c45fSAndroid Build Coastguard Worker                               mAggregatedMetrics[i].generalErrorsCount);
295*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].offTrajectoryRmseMillipixels =
296*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(offTrajectoryRmse * 1000);
297*38e8c45fSAndroid Build Coastguard Worker 
298*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].pressureSumSquaredErrors < 0,
299*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].pressureSumSquaredErrors = %f should not "
300*38e8c45fSAndroid Build Coastguard Worker                                 "be negative",
301*38e8c45fSAndroid Build Coastguard Worker                                 i, mAggregatedMetrics[i].pressureSumSquaredErrors);
302*38e8c45fSAndroid Build Coastguard Worker             const float pressureRmse = std::sqrt(mAggregatedMetrics[i].pressureSumSquaredErrors /
303*38e8c45fSAndroid Build Coastguard Worker                                                  mAggregatedMetrics[i].generalErrorsCount);
304*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].pressureRmseMilliunits = static_cast<int>(pressureRmse * 1000);
305*38e8c45fSAndroid Build Coastguard Worker         }
306*38e8c45fSAndroid Build Coastguard Worker 
307*38e8c45fSAndroid Build Coastguard Worker         // High-velocity errors: reported only for last two time buckets.
308*38e8c45fSAndroid Build Coastguard Worker         // Check if we are in one of the last two time buckets, and there is high-velocity data.
309*38e8c45fSAndroid Build Coastguard Worker         if ((i + 2 >= mMaxNumPredictions) && (mAggregatedMetrics[i].highVelocityErrorsCount > 0)) {
310*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityAlongTrajectorySse < 0,
311*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].highVelocityAlongTrajectorySse = %f "
312*38e8c45fSAndroid Build Coastguard Worker                                 "should not be negative",
313*38e8c45fSAndroid Build Coastguard Worker                                 i, mAggregatedMetrics[i].highVelocityAlongTrajectorySse);
314*38e8c45fSAndroid Build Coastguard Worker             const float alongTrajectoryRmse =
315*38e8c45fSAndroid Build Coastguard Worker                     std::sqrt(mAggregatedMetrics[i].highVelocityAlongTrajectorySse /
316*38e8c45fSAndroid Build Coastguard Worker                               mAggregatedMetrics[i].highVelocityErrorsCount);
317*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].highVelocityAlongTrajectoryRmse =
318*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(alongTrajectoryRmse * 1000);
319*38e8c45fSAndroid Build Coastguard Worker 
320*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityOffTrajectorySse < 0,
321*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].highVelocityOffTrajectorySse = %f should "
322*38e8c45fSAndroid Build Coastguard Worker                                 "not be negative",
323*38e8c45fSAndroid Build Coastguard Worker                                 i, mAggregatedMetrics[i].highVelocityOffTrajectorySse);
324*38e8c45fSAndroid Build Coastguard Worker             const float offTrajectoryRmse =
325*38e8c45fSAndroid Build Coastguard Worker                     std::sqrt(mAggregatedMetrics[i].highVelocityOffTrajectorySse /
326*38e8c45fSAndroid Build Coastguard Worker                               mAggregatedMetrics[i].highVelocityErrorsCount);
327*38e8c45fSAndroid Build Coastguard Worker             mAtomFields[i].highVelocityOffTrajectoryRmse =
328*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(offTrajectoryRmse * 1000);
329*38e8c45fSAndroid Build Coastguard Worker         }
330*38e8c45fSAndroid Build Coastguard Worker     }
331*38e8c45fSAndroid Build Coastguard Worker 
332*38e8c45fSAndroid Build Coastguard Worker     // Scale-invariant errors: the average scale-invariant error across all time buckets
333*38e8c45fSAndroid Build Coastguard Worker     // is reported in the last time bucket.
334*38e8c45fSAndroid Build Coastguard Worker     {
335*38e8c45fSAndroid Build Coastguard Worker         // Compute error averages.
336*38e8c45fSAndroid Build Coastguard Worker         float alongTrajectoryRmseSum = 0;
337*38e8c45fSAndroid Build Coastguard Worker         float offTrajectoryRmseSum = 0;
338*38e8c45fSAndroid Build Coastguard Worker         int bucket_count = 0;
339*38e8c45fSAndroid Build Coastguard Worker         for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) {
340*38e8c45fSAndroid Build Coastguard Worker             if (mAggregatedMetrics[j].scaleInvariantErrorsCount == 0) {
341*38e8c45fSAndroid Build Coastguard Worker                 continue;
342*38e8c45fSAndroid Build Coastguard Worker             }
343*38e8c45fSAndroid Build Coastguard Worker 
344*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0,
345*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f "
346*38e8c45fSAndroid Build Coastguard Worker                                 "should not be negative",
347*38e8c45fSAndroid Build Coastguard Worker                                 j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse);
348*38e8c45fSAndroid Build Coastguard Worker             alongTrajectoryRmseSum +=
349*38e8c45fSAndroid Build Coastguard Worker                     std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse /
350*38e8c45fSAndroid Build Coastguard Worker                               mAggregatedMetrics[j].scaleInvariantErrorsCount);
351*38e8c45fSAndroid Build Coastguard Worker 
352*38e8c45fSAndroid Build Coastguard Worker             LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0,
353*38e8c45fSAndroid Build Coastguard Worker                                 "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f "
354*38e8c45fSAndroid Build Coastguard Worker                                 "should not be negative",
355*38e8c45fSAndroid Build Coastguard Worker                                 j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse);
356*38e8c45fSAndroid Build Coastguard Worker             offTrajectoryRmseSum += std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse /
357*38e8c45fSAndroid Build Coastguard Worker                                               mAggregatedMetrics[j].scaleInvariantErrorsCount);
358*38e8c45fSAndroid Build Coastguard Worker 
359*38e8c45fSAndroid Build Coastguard Worker             ++bucket_count;
360*38e8c45fSAndroid Build Coastguard Worker         }
361*38e8c45fSAndroid Build Coastguard Worker 
362*38e8c45fSAndroid Build Coastguard Worker         if (bucket_count > 0) {
363*38e8c45fSAndroid Build Coastguard Worker             const float averageAlongTrajectoryRmse = alongTrajectoryRmseSum / bucket_count;
364*38e8c45fSAndroid Build Coastguard Worker             mAtomFields.back().scaleInvariantAlongTrajectoryRmse =
365*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(averageAlongTrajectoryRmse * 1000);
366*38e8c45fSAndroid Build Coastguard Worker 
367*38e8c45fSAndroid Build Coastguard Worker             const float averageOffTrajectoryRmse = offTrajectoryRmseSum / bucket_count;
368*38e8c45fSAndroid Build Coastguard Worker             mAtomFields.back().scaleInvariantOffTrajectoryRmse =
369*38e8c45fSAndroid Build Coastguard Worker                     static_cast<int>(averageOffTrajectoryRmse * 1000);
370*38e8c45fSAndroid Build Coastguard Worker         }
371*38e8c45fSAndroid Build Coastguard Worker     }
372*38e8c45fSAndroid Build Coastguard Worker }
373*38e8c45fSAndroid Build Coastguard Worker 
reportMetrics()374*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::reportMetrics() {
375*38e8c45fSAndroid Build Coastguard Worker     LOG_ALWAYS_FATAL_IF(!mReportAtomFunction);
376*38e8c45fSAndroid Build Coastguard Worker     // Report one atom for each prediction time bucket.
377*38e8c45fSAndroid Build Coastguard Worker     for (size_t i = 0; i < mAtomFields.size(); ++i) {
378*38e8c45fSAndroid Build Coastguard Worker         mReportAtomFunction(mAtomFields[i]);
379*38e8c45fSAndroid Build Coastguard Worker     }
380*38e8c45fSAndroid Build Coastguard Worker }
381*38e8c45fSAndroid Build Coastguard Worker 
382*38e8c45fSAndroid Build Coastguard Worker } // namespace android
383