1*38e8c45fSAndroid Build Coastguard Worker /*
2*38e8c45fSAndroid Build Coastguard Worker * Copyright 2023 The Android Open Source Project
3*38e8c45fSAndroid Build Coastguard Worker *
4*38e8c45fSAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*38e8c45fSAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*38e8c45fSAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*38e8c45fSAndroid Build Coastguard Worker *
8*38e8c45fSAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*38e8c45fSAndroid Build Coastguard Worker *
10*38e8c45fSAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*38e8c45fSAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*38e8c45fSAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*38e8c45fSAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*38e8c45fSAndroid Build Coastguard Worker * limitations under the License.
15*38e8c45fSAndroid Build Coastguard Worker */
16*38e8c45fSAndroid Build Coastguard Worker
17*38e8c45fSAndroid Build Coastguard Worker #define LOG_TAG "MotionPredictorMetricsManager"
18*38e8c45fSAndroid Build Coastguard Worker
19*38e8c45fSAndroid Build Coastguard Worker #include <input/MotionPredictorMetricsManager.h>
20*38e8c45fSAndroid Build Coastguard Worker
21*38e8c45fSAndroid Build Coastguard Worker #include <algorithm>
22*38e8c45fSAndroid Build Coastguard Worker
23*38e8c45fSAndroid Build Coastguard Worker #include <android-base/logging.h>
24*38e8c45fSAndroid Build Coastguard Worker #ifdef __ANDROID__
25*38e8c45fSAndroid Build Coastguard Worker #include <statslog_libinput.h>
26*38e8c45fSAndroid Build Coastguard Worker #endif // __ANDROID__
27*38e8c45fSAndroid Build Coastguard Worker
28*38e8c45fSAndroid Build Coastguard Worker #include "Eigen/Core"
29*38e8c45fSAndroid Build Coastguard Worker #include "Eigen/Geometry"
30*38e8c45fSAndroid Build Coastguard Worker
31*38e8c45fSAndroid Build Coastguard Worker namespace android {
namespace {

// Unit conversion factors used with nsecs_t timestamps and intervals.
constexpr int NANOS_PER_SECOND = 1000000000; // nanoseconds in one second
constexpr int NANOS_PER_MILLIS = 1000000;    // nanoseconds in one millisecond

// Ground-truth speed, in pixels per second, above which a prediction additionally contributes
// to the "high-velocity" metrics. Picked by manual experimentation as the boundary between fast
// (somewhat sloppy) handwriting and slower, more careful handwriting.
constexpr float HIGH_VELOCITY_THRESHOLD = 1100.0f;

// Tiny offset added to the stroke path length before it is used as the divisor of the
// scale-invariant errors, so that the divisor is never zero.
constexpr float PATH_LENGTH_EPSILON = 0.001;

} // namespace
47*38e8c45fSAndroid Build Coastguard Worker
defaultReportAtomFunction(const MotionPredictorMetricsManager::AtomFields & atomFields)48*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::defaultReportAtomFunction(
49*38e8c45fSAndroid Build Coastguard Worker const MotionPredictorMetricsManager::AtomFields& atomFields) {
50*38e8c45fSAndroid Build Coastguard Worker #ifdef __ANDROID__
51*38e8c45fSAndroid Build Coastguard Worker android::libinput::stats_write(android::libinput::STYLUS_PREDICTION_METRICS_REPORTED,
52*38e8c45fSAndroid Build Coastguard Worker /*stylus_vendor_id=*/0,
53*38e8c45fSAndroid Build Coastguard Worker /*stylus_product_id=*/0,
54*38e8c45fSAndroid Build Coastguard Worker atomFields.deltaTimeBucketMilliseconds,
55*38e8c45fSAndroid Build Coastguard Worker atomFields.alongTrajectoryErrorMeanMillipixels,
56*38e8c45fSAndroid Build Coastguard Worker atomFields.alongTrajectoryErrorStdMillipixels,
57*38e8c45fSAndroid Build Coastguard Worker atomFields.offTrajectoryRmseMillipixels,
58*38e8c45fSAndroid Build Coastguard Worker atomFields.pressureRmseMilliunits,
59*38e8c45fSAndroid Build Coastguard Worker atomFields.highVelocityAlongTrajectoryRmse,
60*38e8c45fSAndroid Build Coastguard Worker atomFields.highVelocityOffTrajectoryRmse,
61*38e8c45fSAndroid Build Coastguard Worker atomFields.scaleInvariantAlongTrajectoryRmse,
62*38e8c45fSAndroid Build Coastguard Worker atomFields.scaleInvariantOffTrajectoryRmse);
63*38e8c45fSAndroid Build Coastguard Worker #endif // __ANDROID__
64*38e8c45fSAndroid Build Coastguard Worker }
65*38e8c45fSAndroid Build Coastguard Worker
MotionPredictorMetricsManager(nsecs_t predictionInterval,size_t maxNumPredictions,ReportAtomFunction reportAtomFunction)66*38e8c45fSAndroid Build Coastguard Worker MotionPredictorMetricsManager::MotionPredictorMetricsManager(
67*38e8c45fSAndroid Build Coastguard Worker nsecs_t predictionInterval,
68*38e8c45fSAndroid Build Coastguard Worker size_t maxNumPredictions,
69*38e8c45fSAndroid Build Coastguard Worker ReportAtomFunction reportAtomFunction)
70*38e8c45fSAndroid Build Coastguard Worker : mPredictionInterval(predictionInterval),
71*38e8c45fSAndroid Build Coastguard Worker mMaxNumPredictions(maxNumPredictions),
72*38e8c45fSAndroid Build Coastguard Worker mRecentGroundTruthPoints(maxNumPredictions + 1),
73*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics(maxNumPredictions),
74*38e8c45fSAndroid Build Coastguard Worker mAtomFields(maxNumPredictions),
75*38e8c45fSAndroid Build Coastguard Worker mReportAtomFunction(reportAtomFunction ? reportAtomFunction : defaultReportAtomFunction) {}
76*38e8c45fSAndroid Build Coastguard Worker
onRecord(const MotionEvent & inputEvent)77*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::onRecord(const MotionEvent& inputEvent) {
78*38e8c45fSAndroid Build Coastguard Worker // Convert MotionEvent to GroundTruthPoint.
79*38e8c45fSAndroid Build Coastguard Worker const PointerCoords* coords = inputEvent.getRawPointerCoords(/*pointerIndex=*/0);
80*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(coords == nullptr);
81*38e8c45fSAndroid Build Coastguard Worker const GroundTruthPoint groundTruthPoint{{.position = Eigen::Vector2f{coords->getY(),
82*38e8c45fSAndroid Build Coastguard Worker coords->getX()},
83*38e8c45fSAndroid Build Coastguard Worker .pressure =
84*38e8c45fSAndroid Build Coastguard Worker inputEvent.getPressure(/*pointerIndex=*/0)},
85*38e8c45fSAndroid Build Coastguard Worker .timestamp = inputEvent.getEventTime()};
86*38e8c45fSAndroid Build Coastguard Worker
87*38e8c45fSAndroid Build Coastguard Worker // Handle event based on action type.
88*38e8c45fSAndroid Build Coastguard Worker switch (inputEvent.getActionMasked()) {
89*38e8c45fSAndroid Build Coastguard Worker case AMOTION_EVENT_ACTION_DOWN: {
90*38e8c45fSAndroid Build Coastguard Worker clearStrokeData();
91*38e8c45fSAndroid Build Coastguard Worker incorporateNewGroundTruth(groundTruthPoint);
92*38e8c45fSAndroid Build Coastguard Worker break;
93*38e8c45fSAndroid Build Coastguard Worker }
94*38e8c45fSAndroid Build Coastguard Worker case AMOTION_EVENT_ACTION_MOVE: {
95*38e8c45fSAndroid Build Coastguard Worker incorporateNewGroundTruth(groundTruthPoint);
96*38e8c45fSAndroid Build Coastguard Worker break;
97*38e8c45fSAndroid Build Coastguard Worker }
98*38e8c45fSAndroid Build Coastguard Worker case AMOTION_EVENT_ACTION_UP:
99*38e8c45fSAndroid Build Coastguard Worker case AMOTION_EVENT_ACTION_CANCEL: {
100*38e8c45fSAndroid Build Coastguard Worker // Only expect meaningful predictions when given at least two input points.
101*38e8c45fSAndroid Build Coastguard Worker if (mRecentGroundTruthPoints.size() >= 2) {
102*38e8c45fSAndroid Build Coastguard Worker computeAtomFields();
103*38e8c45fSAndroid Build Coastguard Worker reportMetrics();
104*38e8c45fSAndroid Build Coastguard Worker }
105*38e8c45fSAndroid Build Coastguard Worker break;
106*38e8c45fSAndroid Build Coastguard Worker }
107*38e8c45fSAndroid Build Coastguard Worker }
108*38e8c45fSAndroid Build Coastguard Worker }
109*38e8c45fSAndroid Build Coastguard Worker
110*38e8c45fSAndroid Build Coastguard Worker // Adds new predictions to mRecentPredictions and maintains the invariant that elements are
111*38e8c45fSAndroid Build Coastguard Worker // sorted in ascending order of targetTimestamp.
onPredict(const MotionEvent & predictionEvent)112*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::onPredict(const MotionEvent& predictionEvent) {
113*38e8c45fSAndroid Build Coastguard Worker const size_t numPredictions = predictionEvent.getHistorySize() + 1;
114*38e8c45fSAndroid Build Coastguard Worker if (numPredictions > mMaxNumPredictions) {
115*38e8c45fSAndroid Build Coastguard Worker LOG(WARNING) << "numPredictions (" << numPredictions << ") > mMaxNumPredictions ("
116*38e8c45fSAndroid Build Coastguard Worker << mMaxNumPredictions << "). Ignoring extra predictions in metrics.";
117*38e8c45fSAndroid Build Coastguard Worker }
118*38e8c45fSAndroid Build Coastguard Worker for (size_t i = 0; (i < numPredictions) && (i < mMaxNumPredictions); ++i) {
119*38e8c45fSAndroid Build Coastguard Worker // Convert MotionEvent to PredictionPoint.
120*38e8c45fSAndroid Build Coastguard Worker const PointerCoords* coords =
121*38e8c45fSAndroid Build Coastguard Worker predictionEvent.getHistoricalRawPointerCoords(/*pointerIndex=*/0, i);
122*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(coords == nullptr);
123*38e8c45fSAndroid Build Coastguard Worker const nsecs_t targetTimestamp = predictionEvent.getHistoricalEventTime(i);
124*38e8c45fSAndroid Build Coastguard Worker mRecentPredictions.push_back(
125*38e8c45fSAndroid Build Coastguard Worker PredictionPoint{{.position = Eigen::Vector2f{coords->getY(), coords->getX()},
126*38e8c45fSAndroid Build Coastguard Worker .pressure =
127*38e8c45fSAndroid Build Coastguard Worker predictionEvent.getHistoricalPressure(/*pointerIndex=*/0,
128*38e8c45fSAndroid Build Coastguard Worker i)},
129*38e8c45fSAndroid Build Coastguard Worker .originTimestamp = mRecentGroundTruthPoints.back().timestamp,
130*38e8c45fSAndroid Build Coastguard Worker .targetTimestamp = targetTimestamp});
131*38e8c45fSAndroid Build Coastguard Worker }
132*38e8c45fSAndroid Build Coastguard Worker
133*38e8c45fSAndroid Build Coastguard Worker std::sort(mRecentPredictions.begin(), mRecentPredictions.end());
134*38e8c45fSAndroid Build Coastguard Worker }
135*38e8c45fSAndroid Build Coastguard Worker
clearStrokeData()136*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::clearStrokeData() {
137*38e8c45fSAndroid Build Coastguard Worker mRecentGroundTruthPoints.clear();
138*38e8c45fSAndroid Build Coastguard Worker mRecentPredictions.clear();
139*38e8c45fSAndroid Build Coastguard Worker std::fill(mAggregatedMetrics.begin(), mAggregatedMetrics.end(), AggregatedStrokeMetrics{});
140*38e8c45fSAndroid Build Coastguard Worker std::fill(mAtomFields.begin(), mAtomFields.end(), AtomFields{});
141*38e8c45fSAndroid Build Coastguard Worker }
142*38e8c45fSAndroid Build Coastguard Worker
incorporateNewGroundTruth(const GroundTruthPoint & groundTruthPoint)143*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::incorporateNewGroundTruth(
144*38e8c45fSAndroid Build Coastguard Worker const GroundTruthPoint& groundTruthPoint) {
145*38e8c45fSAndroid Build Coastguard Worker // Note: this removes the oldest point if `mRecentGroundTruthPoints` is already at capacity.
146*38e8c45fSAndroid Build Coastguard Worker mRecentGroundTruthPoints.pushBack(groundTruthPoint);
147*38e8c45fSAndroid Build Coastguard Worker
148*38e8c45fSAndroid Build Coastguard Worker // Remove outdated predictions – those that can never be matched with the current or any future
149*38e8c45fSAndroid Build Coastguard Worker // ground truth points. We use fuzzy association for the timestamps here, because ground truth
150*38e8c45fSAndroid Build Coastguard Worker // and prediction timestamps may not be perfectly synchronized.
151*38e8c45fSAndroid Build Coastguard Worker const nsecs_t fuzzy_association_time_delta = mPredictionInterval / 4;
152*38e8c45fSAndroid Build Coastguard Worker const auto firstCurrentIt =
153*38e8c45fSAndroid Build Coastguard Worker std::find_if(mRecentPredictions.begin(), mRecentPredictions.end(),
154*38e8c45fSAndroid Build Coastguard Worker [&groundTruthPoint,
155*38e8c45fSAndroid Build Coastguard Worker fuzzy_association_time_delta](const PredictionPoint& prediction) {
156*38e8c45fSAndroid Build Coastguard Worker return prediction.targetTimestamp >
157*38e8c45fSAndroid Build Coastguard Worker groundTruthPoint.timestamp - fuzzy_association_time_delta;
158*38e8c45fSAndroid Build Coastguard Worker });
159*38e8c45fSAndroid Build Coastguard Worker mRecentPredictions.erase(mRecentPredictions.begin(), firstCurrentIt);
160*38e8c45fSAndroid Build Coastguard Worker
161*38e8c45fSAndroid Build Coastguard Worker // Fuzzily match the new ground truth's timestamp to recent predictions' targetTimestamp and
162*38e8c45fSAndroid Build Coastguard Worker // update the corresponding metrics.
163*38e8c45fSAndroid Build Coastguard Worker for (const PredictionPoint& prediction : mRecentPredictions) {
164*38e8c45fSAndroid Build Coastguard Worker if ((prediction.targetTimestamp >
165*38e8c45fSAndroid Build Coastguard Worker groundTruthPoint.timestamp - fuzzy_association_time_delta) &&
166*38e8c45fSAndroid Build Coastguard Worker (prediction.targetTimestamp <
167*38e8c45fSAndroid Build Coastguard Worker groundTruthPoint.timestamp + fuzzy_association_time_delta)) {
168*38e8c45fSAndroid Build Coastguard Worker updateAggregatedMetrics(prediction);
169*38e8c45fSAndroid Build Coastguard Worker }
170*38e8c45fSAndroid Build Coastguard Worker }
171*38e8c45fSAndroid Build Coastguard Worker }
172*38e8c45fSAndroid Build Coastguard Worker
updateAggregatedMetrics(const PredictionPoint & predictionPoint)173*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::updateAggregatedMetrics(
174*38e8c45fSAndroid Build Coastguard Worker const PredictionPoint& predictionPoint) {
175*38e8c45fSAndroid Build Coastguard Worker if (mRecentGroundTruthPoints.size() < 2) {
176*38e8c45fSAndroid Build Coastguard Worker return;
177*38e8c45fSAndroid Build Coastguard Worker }
178*38e8c45fSAndroid Build Coastguard Worker
179*38e8c45fSAndroid Build Coastguard Worker const GroundTruthPoint& latestGroundTruthPoint = mRecentGroundTruthPoints.back();
180*38e8c45fSAndroid Build Coastguard Worker const GroundTruthPoint& previousGroundTruthPoint =
181*38e8c45fSAndroid Build Coastguard Worker mRecentGroundTruthPoints[mRecentGroundTruthPoints.size() - 2];
182*38e8c45fSAndroid Build Coastguard Worker // Calculate prediction error vector.
183*38e8c45fSAndroid Build Coastguard Worker const Eigen::Vector2f groundTruthTrajectory =
184*38e8c45fSAndroid Build Coastguard Worker latestGroundTruthPoint.position - previousGroundTruthPoint.position;
185*38e8c45fSAndroid Build Coastguard Worker const Eigen::Vector2f predictionTrajectory =
186*38e8c45fSAndroid Build Coastguard Worker predictionPoint.position - previousGroundTruthPoint.position;
187*38e8c45fSAndroid Build Coastguard Worker const Eigen::Vector2f predictionError = predictionTrajectory - groundTruthTrajectory;
188*38e8c45fSAndroid Build Coastguard Worker
189*38e8c45fSAndroid Build Coastguard Worker // By default, prediction error counts fully as both off-trajectory and along-trajectory error.
190*38e8c45fSAndroid Build Coastguard Worker // This serves as the fallback when the two most recent ground truth points are equal.
191*38e8c45fSAndroid Build Coastguard Worker const float predictionErrorNorm = predictionError.norm();
192*38e8c45fSAndroid Build Coastguard Worker float alongTrajectoryError = predictionErrorNorm;
193*38e8c45fSAndroid Build Coastguard Worker float offTrajectoryError = predictionErrorNorm;
194*38e8c45fSAndroid Build Coastguard Worker if (groundTruthTrajectory.squaredNorm() > 0) {
195*38e8c45fSAndroid Build Coastguard Worker // Rotate the prediction error vector by the angle of the ground truth trajectory vector.
196*38e8c45fSAndroid Build Coastguard Worker // This yields a vector whose first component is the along-trajectory error and whose
197*38e8c45fSAndroid Build Coastguard Worker // second component is the off-trajectory error.
198*38e8c45fSAndroid Build Coastguard Worker const float theta = std::atan2(groundTruthTrajectory[1], groundTruthTrajectory[0]);
199*38e8c45fSAndroid Build Coastguard Worker const Eigen::Vector2f rotatedPredictionError = Eigen::Rotation2Df(-theta) * predictionError;
200*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryError = rotatedPredictionError[0];
201*38e8c45fSAndroid Build Coastguard Worker offTrajectoryError = rotatedPredictionError[1];
202*38e8c45fSAndroid Build Coastguard Worker }
203*38e8c45fSAndroid Build Coastguard Worker
204*38e8c45fSAndroid Build Coastguard Worker // Compute the multiple of mPredictionInterval nearest to the amount of time into the
205*38e8c45fSAndroid Build Coastguard Worker // future being predicted. This serves as the time bucket index into mAggregatedMetrics.
206*38e8c45fSAndroid Build Coastguard Worker const float timestampDeltaFloat =
207*38e8c45fSAndroid Build Coastguard Worker static_cast<float>(predictionPoint.targetTimestamp - predictionPoint.originTimestamp);
208*38e8c45fSAndroid Build Coastguard Worker const size_t tIndex =
209*38e8c45fSAndroid Build Coastguard Worker static_cast<size_t>(std::round(timestampDeltaFloat / mPredictionInterval - 1));
210*38e8c45fSAndroid Build Coastguard Worker
211*38e8c45fSAndroid Build Coastguard Worker // Aggregate values into "general errors".
212*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].alongTrajectoryErrorSum += alongTrajectoryError;
213*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].alongTrajectorySumSquaredErrors +=
214*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryError * alongTrajectoryError;
215*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].offTrajectorySumSquaredErrors +=
216*38e8c45fSAndroid Build Coastguard Worker offTrajectoryError * offTrajectoryError;
217*38e8c45fSAndroid Build Coastguard Worker const float pressureError = predictionPoint.pressure - latestGroundTruthPoint.pressure;
218*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].pressureSumSquaredErrors += pressureError * pressureError;
219*38e8c45fSAndroid Build Coastguard Worker ++mAggregatedMetrics[tIndex].generalErrorsCount;
220*38e8c45fSAndroid Build Coastguard Worker
221*38e8c45fSAndroid Build Coastguard Worker // Aggregate values into high-velocity metrics, if we are in one of the last two time buckets
222*38e8c45fSAndroid Build Coastguard Worker // and the velocity is above the threshold. Velocity here is measured in pixels per second.
223*38e8c45fSAndroid Build Coastguard Worker const float velocity = groundTruthTrajectory.norm() /
224*38e8c45fSAndroid Build Coastguard Worker (static_cast<float>(latestGroundTruthPoint.timestamp -
225*38e8c45fSAndroid Build Coastguard Worker previousGroundTruthPoint.timestamp) /
226*38e8c45fSAndroid Build Coastguard Worker NANOS_PER_SECOND);
227*38e8c45fSAndroid Build Coastguard Worker if ((tIndex + 2 >= mMaxNumPredictions) && (velocity > HIGH_VELOCITY_THRESHOLD)) {
228*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].highVelocityAlongTrajectorySse +=
229*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryError * alongTrajectoryError;
230*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].highVelocityOffTrajectorySse +=
231*38e8c45fSAndroid Build Coastguard Worker offTrajectoryError * offTrajectoryError;
232*38e8c45fSAndroid Build Coastguard Worker ++mAggregatedMetrics[tIndex].highVelocityErrorsCount;
233*38e8c45fSAndroid Build Coastguard Worker }
234*38e8c45fSAndroid Build Coastguard Worker
235*38e8c45fSAndroid Build Coastguard Worker // Compute path length for scale-invariant errors.
236*38e8c45fSAndroid Build Coastguard Worker float pathLength = 0;
237*38e8c45fSAndroid Build Coastguard Worker for (size_t i = 1; i < mRecentGroundTruthPoints.size(); ++i) {
238*38e8c45fSAndroid Build Coastguard Worker pathLength +=
239*38e8c45fSAndroid Build Coastguard Worker (mRecentGroundTruthPoints[i].position - mRecentGroundTruthPoints[i - 1].position)
240*38e8c45fSAndroid Build Coastguard Worker .norm();
241*38e8c45fSAndroid Build Coastguard Worker }
242*38e8c45fSAndroid Build Coastguard Worker // Avoid overweighting errors at the beginning of a stroke: compute the path length as if there
243*38e8c45fSAndroid Build Coastguard Worker // were a full ground truth history by filling in missing segments with the average length.
244*38e8c45fSAndroid Build Coastguard Worker // Note: the "- 1" is needed to translate from number of endpoints to number of segments.
245*38e8c45fSAndroid Build Coastguard Worker pathLength *= static_cast<float>(mRecentGroundTruthPoints.capacity() - 1) /
246*38e8c45fSAndroid Build Coastguard Worker (mRecentGroundTruthPoints.size() - 1);
247*38e8c45fSAndroid Build Coastguard Worker pathLength += PATH_LENGTH_EPSILON; // Ensure path length is nonzero (>= PATH_LENGTH_EPSILON).
248*38e8c45fSAndroid Build Coastguard Worker
249*38e8c45fSAndroid Build Coastguard Worker // Compute and aggregate scale-invariant errors.
250*38e8c45fSAndroid Build Coastguard Worker const float scaleInvariantAlongTrajectoryError = alongTrajectoryError / pathLength;
251*38e8c45fSAndroid Build Coastguard Worker const float scaleInvariantOffTrajectoryError = offTrajectoryError / pathLength;
252*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].scaleInvariantAlongTrajectorySse +=
253*38e8c45fSAndroid Build Coastguard Worker scaleInvariantAlongTrajectoryError * scaleInvariantAlongTrajectoryError;
254*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[tIndex].scaleInvariantOffTrajectorySse +=
255*38e8c45fSAndroid Build Coastguard Worker scaleInvariantOffTrajectoryError * scaleInvariantOffTrajectoryError;
256*38e8c45fSAndroid Build Coastguard Worker ++mAggregatedMetrics[tIndex].scaleInvariantErrorsCount;
257*38e8c45fSAndroid Build Coastguard Worker }
258*38e8c45fSAndroid Build Coastguard Worker
computeAtomFields()259*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::computeAtomFields() {
260*38e8c45fSAndroid Build Coastguard Worker for (size_t i = 0; i < mAggregatedMetrics.size(); ++i) {
261*38e8c45fSAndroid Build Coastguard Worker if (mAggregatedMetrics[i].generalErrorsCount == 0) {
262*38e8c45fSAndroid Build Coastguard Worker // We have not received data corresponding to metrics for this time bucket.
263*38e8c45fSAndroid Build Coastguard Worker continue;
264*38e8c45fSAndroid Build Coastguard Worker }
265*38e8c45fSAndroid Build Coastguard Worker
266*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].deltaTimeBucketMilliseconds =
267*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(mPredictionInterval / NANOS_PER_MILLIS * (i + 1));
268*38e8c45fSAndroid Build Coastguard Worker
269*38e8c45fSAndroid Build Coastguard Worker // Note: we need the "* 1000"s below because we report values in integral milli-units.
270*38e8c45fSAndroid Build Coastguard Worker
271*38e8c45fSAndroid Build Coastguard Worker { // General errors: reported for every time bucket.
272*38e8c45fSAndroid Build Coastguard Worker const float alongTrajectoryErrorMean = mAggregatedMetrics[i].alongTrajectoryErrorSum /
273*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].generalErrorsCount;
274*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].alongTrajectoryErrorMeanMillipixels =
275*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(alongTrajectoryErrorMean * 1000);
276*38e8c45fSAndroid Build Coastguard Worker
277*38e8c45fSAndroid Build Coastguard Worker const float alongTrajectoryMse = mAggregatedMetrics[i].alongTrajectorySumSquaredErrors /
278*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].generalErrorsCount;
279*38e8c45fSAndroid Build Coastguard Worker // Take the max with 0 to avoid negative values caused by numerical instability.
280*38e8c45fSAndroid Build Coastguard Worker const float alongTrajectoryErrorVariance =
281*38e8c45fSAndroid Build Coastguard Worker std::max(0.0f,
282*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryMse -
283*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryErrorMean * alongTrajectoryErrorMean);
284*38e8c45fSAndroid Build Coastguard Worker const float alongTrajectoryErrorStd = std::sqrt(alongTrajectoryErrorVariance);
285*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].alongTrajectoryErrorStdMillipixels =
286*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(alongTrajectoryErrorStd * 1000);
287*38e8c45fSAndroid Build Coastguard Worker
288*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].offTrajectorySumSquaredErrors < 0,
289*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].offTrajectorySumSquaredErrors = %f should "
290*38e8c45fSAndroid Build Coastguard Worker "not be negative",
291*38e8c45fSAndroid Build Coastguard Worker i, mAggregatedMetrics[i].offTrajectorySumSquaredErrors);
292*38e8c45fSAndroid Build Coastguard Worker const float offTrajectoryRmse =
293*38e8c45fSAndroid Build Coastguard Worker std::sqrt(mAggregatedMetrics[i].offTrajectorySumSquaredErrors /
294*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].generalErrorsCount);
295*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].offTrajectoryRmseMillipixels =
296*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(offTrajectoryRmse * 1000);
297*38e8c45fSAndroid Build Coastguard Worker
298*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].pressureSumSquaredErrors < 0,
299*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].pressureSumSquaredErrors = %f should not "
300*38e8c45fSAndroid Build Coastguard Worker "be negative",
301*38e8c45fSAndroid Build Coastguard Worker i, mAggregatedMetrics[i].pressureSumSquaredErrors);
302*38e8c45fSAndroid Build Coastguard Worker const float pressureRmse = std::sqrt(mAggregatedMetrics[i].pressureSumSquaredErrors /
303*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].generalErrorsCount);
304*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].pressureRmseMilliunits = static_cast<int>(pressureRmse * 1000);
305*38e8c45fSAndroid Build Coastguard Worker }
306*38e8c45fSAndroid Build Coastguard Worker
307*38e8c45fSAndroid Build Coastguard Worker // High-velocity errors: reported only for last two time buckets.
308*38e8c45fSAndroid Build Coastguard Worker // Check if we are in one of the last two time buckets, and there is high-velocity data.
309*38e8c45fSAndroid Build Coastguard Worker if ((i + 2 >= mMaxNumPredictions) && (mAggregatedMetrics[i].highVelocityErrorsCount > 0)) {
310*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityAlongTrajectorySse < 0,
311*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].highVelocityAlongTrajectorySse = %f "
312*38e8c45fSAndroid Build Coastguard Worker "should not be negative",
313*38e8c45fSAndroid Build Coastguard Worker i, mAggregatedMetrics[i].highVelocityAlongTrajectorySse);
314*38e8c45fSAndroid Build Coastguard Worker const float alongTrajectoryRmse =
315*38e8c45fSAndroid Build Coastguard Worker std::sqrt(mAggregatedMetrics[i].highVelocityAlongTrajectorySse /
316*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].highVelocityErrorsCount);
317*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].highVelocityAlongTrajectoryRmse =
318*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(alongTrajectoryRmse * 1000);
319*38e8c45fSAndroid Build Coastguard Worker
320*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[i].highVelocityOffTrajectorySse < 0,
321*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].highVelocityOffTrajectorySse = %f should "
322*38e8c45fSAndroid Build Coastguard Worker "not be negative",
323*38e8c45fSAndroid Build Coastguard Worker i, mAggregatedMetrics[i].highVelocityOffTrajectorySse);
324*38e8c45fSAndroid Build Coastguard Worker const float offTrajectoryRmse =
325*38e8c45fSAndroid Build Coastguard Worker std::sqrt(mAggregatedMetrics[i].highVelocityOffTrajectorySse /
326*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[i].highVelocityErrorsCount);
327*38e8c45fSAndroid Build Coastguard Worker mAtomFields[i].highVelocityOffTrajectoryRmse =
328*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(offTrajectoryRmse * 1000);
329*38e8c45fSAndroid Build Coastguard Worker }
330*38e8c45fSAndroid Build Coastguard Worker }
331*38e8c45fSAndroid Build Coastguard Worker
332*38e8c45fSAndroid Build Coastguard Worker // Scale-invariant errors: the average scale-invariant error across all time buckets
333*38e8c45fSAndroid Build Coastguard Worker // is reported in the last time bucket.
334*38e8c45fSAndroid Build Coastguard Worker {
335*38e8c45fSAndroid Build Coastguard Worker // Compute error averages.
336*38e8c45fSAndroid Build Coastguard Worker float alongTrajectoryRmseSum = 0;
337*38e8c45fSAndroid Build Coastguard Worker float offTrajectoryRmseSum = 0;
338*38e8c45fSAndroid Build Coastguard Worker int bucket_count = 0;
339*38e8c45fSAndroid Build Coastguard Worker for (size_t j = 0; j < mAggregatedMetrics.size(); ++j) {
340*38e8c45fSAndroid Build Coastguard Worker if (mAggregatedMetrics[j].scaleInvariantErrorsCount == 0) {
341*38e8c45fSAndroid Build Coastguard Worker continue;
342*38e8c45fSAndroid Build Coastguard Worker }
343*38e8c45fSAndroid Build Coastguard Worker
344*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse < 0,
345*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].scaleInvariantAlongTrajectorySse = %f "
346*38e8c45fSAndroid Build Coastguard Worker "should not be negative",
347*38e8c45fSAndroid Build Coastguard Worker j, mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse);
348*38e8c45fSAndroid Build Coastguard Worker alongTrajectoryRmseSum +=
349*38e8c45fSAndroid Build Coastguard Worker std::sqrt(mAggregatedMetrics[j].scaleInvariantAlongTrajectorySse /
350*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[j].scaleInvariantErrorsCount);
351*38e8c45fSAndroid Build Coastguard Worker
352*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse < 0,
353*38e8c45fSAndroid Build Coastguard Worker "mAggregatedMetrics[%zu].scaleInvariantOffTrajectorySse = %f "
354*38e8c45fSAndroid Build Coastguard Worker "should not be negative",
355*38e8c45fSAndroid Build Coastguard Worker j, mAggregatedMetrics[j].scaleInvariantOffTrajectorySse);
356*38e8c45fSAndroid Build Coastguard Worker offTrajectoryRmseSum += std::sqrt(mAggregatedMetrics[j].scaleInvariantOffTrajectorySse /
357*38e8c45fSAndroid Build Coastguard Worker mAggregatedMetrics[j].scaleInvariantErrorsCount);
358*38e8c45fSAndroid Build Coastguard Worker
359*38e8c45fSAndroid Build Coastguard Worker ++bucket_count;
360*38e8c45fSAndroid Build Coastguard Worker }
361*38e8c45fSAndroid Build Coastguard Worker
362*38e8c45fSAndroid Build Coastguard Worker if (bucket_count > 0) {
363*38e8c45fSAndroid Build Coastguard Worker const float averageAlongTrajectoryRmse = alongTrajectoryRmseSum / bucket_count;
364*38e8c45fSAndroid Build Coastguard Worker mAtomFields.back().scaleInvariantAlongTrajectoryRmse =
365*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(averageAlongTrajectoryRmse * 1000);
366*38e8c45fSAndroid Build Coastguard Worker
367*38e8c45fSAndroid Build Coastguard Worker const float averageOffTrajectoryRmse = offTrajectoryRmseSum / bucket_count;
368*38e8c45fSAndroid Build Coastguard Worker mAtomFields.back().scaleInvariantOffTrajectoryRmse =
369*38e8c45fSAndroid Build Coastguard Worker static_cast<int>(averageOffTrajectoryRmse * 1000);
370*38e8c45fSAndroid Build Coastguard Worker }
371*38e8c45fSAndroid Build Coastguard Worker }
372*38e8c45fSAndroid Build Coastguard Worker }
373*38e8c45fSAndroid Build Coastguard Worker
reportMetrics()374*38e8c45fSAndroid Build Coastguard Worker void MotionPredictorMetricsManager::reportMetrics() {
375*38e8c45fSAndroid Build Coastguard Worker LOG_ALWAYS_FATAL_IF(!mReportAtomFunction);
376*38e8c45fSAndroid Build Coastguard Worker // Report one atom for each prediction time bucket.
377*38e8c45fSAndroid Build Coastguard Worker for (size_t i = 0; i < mAtomFields.size(); ++i) {
378*38e8c45fSAndroid Build Coastguard Worker mReportAtomFunction(mAtomFields[i]);
379*38e8c45fSAndroid Build Coastguard Worker }
380*38e8c45fSAndroid Build Coastguard Worker }
381*38e8c45fSAndroid Build Coastguard Worker
382*38e8c45fSAndroid Build Coastguard Worker } // namespace android
383