xref: /aosp_15_r20/frameworks/native/include/input/MotionPredictorMetricsManager.h (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1*38e8c45fSAndroid Build Coastguard Worker /*
2*38e8c45fSAndroid Build Coastguard Worker  * Copyright 2023 The Android Open Source Project
3*38e8c45fSAndroid Build Coastguard Worker  *
4*38e8c45fSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*38e8c45fSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*38e8c45fSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*38e8c45fSAndroid Build Coastguard Worker  *
8*38e8c45fSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*38e8c45fSAndroid Build Coastguard Worker  *
10*38e8c45fSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*38e8c45fSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*38e8c45fSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*38e8c45fSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*38e8c45fSAndroid Build Coastguard Worker  * limitations under the License.
15*38e8c45fSAndroid Build Coastguard Worker  */
16*38e8c45fSAndroid Build Coastguard Worker 
#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

#include <input/Input.h> // for MotionEvent
#include <input/RingBuffer.h>
#include <utils/Timers.h> // for nsecs_t

#include "Eigen/Core"
28*38e8c45fSAndroid Build Coastguard Worker 
29*38e8c45fSAndroid Build Coastguard Worker namespace android {
30*38e8c45fSAndroid Build Coastguard Worker 
31*38e8c45fSAndroid Build Coastguard Worker /**
32*38e8c45fSAndroid Build Coastguard Worker  * Class to handle computing and reporting metrics for MotionPredictor.
33*38e8c45fSAndroid Build Coastguard Worker  *
34*38e8c45fSAndroid Build Coastguard Worker  * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
35*38e8c45fSAndroid Build Coastguard Worker  * MotionEvents from the corresponding methods in MotionPredictor.
36*38e8c45fSAndroid Build Coastguard Worker  *
37*38e8c45fSAndroid Build Coastguard Worker  * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
38*38e8c45fSAndroid Build Coastguard Worker  * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
39*38e8c45fSAndroid Build Coastguard Worker  * AtomFields are computed and reported to the stats library. The number of atoms reported is equal
40*38e8c45fSAndroid Build Coastguard Worker  * to the value of `maxNumPredictions` passed to the constructor. Each atom corresponds to one
41*38e8c45fSAndroid Build Coastguard Worker  * "prediction time bucket" — the amount of time into the future being predicted.
42*38e8c45fSAndroid Build Coastguard Worker  *
43*38e8c45fSAndroid Build Coastguard Worker  * If mMockLoggedAtomFields is set, the batch of AtomFields that are reported to the stats library
44*38e8c45fSAndroid Build Coastguard Worker  * for one stroke are also stored in mMockLoggedAtomFields at the time they're reported.
45*38e8c45fSAndroid Build Coastguard Worker  */
46*38e8c45fSAndroid Build Coastguard Worker class MotionPredictorMetricsManager {
47*38e8c45fSAndroid Build Coastguard Worker public:
48*38e8c45fSAndroid Build Coastguard Worker     struct AtomFields;
49*38e8c45fSAndroid Build Coastguard Worker 
50*38e8c45fSAndroid Build Coastguard Worker     using ReportAtomFunction = std::function<void(const AtomFields&)>;
51*38e8c45fSAndroid Build Coastguard Worker 
52*38e8c45fSAndroid Build Coastguard Worker     static void defaultReportAtomFunction(const AtomFields& atomFields);
53*38e8c45fSAndroid Build Coastguard Worker 
54*38e8c45fSAndroid Build Coastguard Worker     // Parameters:
55*38e8c45fSAndroid Build Coastguard Worker     //  • predictionInterval: the time interval between successive prediction target timestamps.
56*38e8c45fSAndroid Build Coastguard Worker     //    Note: the MetricsManager assumes that the input interval equals the prediction interval.
57*38e8c45fSAndroid Build Coastguard Worker     //  • maxNumPredictions: the maximum number of distinct target timestamps the prediction model
58*38e8c45fSAndroid Build Coastguard Worker     //    will generate predictions for. The MetricsManager reports this many atoms per stroke.
59*38e8c45fSAndroid Build Coastguard Worker     //  • [Optional] reportAtomFunction: the function that will be called to report metrics. If
60*38e8c45fSAndroid Build Coastguard Worker     //    omitted (or if an empty function is given), the `stats_write(…)` function from the Android
61*38e8c45fSAndroid Build Coastguard Worker     //    stats library will be used.
62*38e8c45fSAndroid Build Coastguard Worker     MotionPredictorMetricsManager(
63*38e8c45fSAndroid Build Coastguard Worker             nsecs_t predictionInterval,
64*38e8c45fSAndroid Build Coastguard Worker             size_t maxNumPredictions,
65*38e8c45fSAndroid Build Coastguard Worker             ReportAtomFunction reportAtomFunction = defaultReportAtomFunction);
66*38e8c45fSAndroid Build Coastguard Worker 
67*38e8c45fSAndroid Build Coastguard Worker     // This method should be called once for each call to MotionPredictor::record, receiving the
68*38e8c45fSAndroid Build Coastguard Worker     // forwarded MotionEvent argument.
69*38e8c45fSAndroid Build Coastguard Worker     void onRecord(const MotionEvent& inputEvent);
70*38e8c45fSAndroid Build Coastguard Worker 
71*38e8c45fSAndroid Build Coastguard Worker     // This method should be called once for each call to MotionPredictor::predict, receiving the
72*38e8c45fSAndroid Build Coastguard Worker     // MotionEvent that will be returned by MotionPredictor::predict.
73*38e8c45fSAndroid Build Coastguard Worker     void onPredict(const MotionEvent& predictionEvent);
74*38e8c45fSAndroid Build Coastguard Worker 
75*38e8c45fSAndroid Build Coastguard Worker     // Simple structs to hold relevant touch input information. Public so they can be used in tests.
76*38e8c45fSAndroid Build Coastguard Worker 
77*38e8c45fSAndroid Build Coastguard Worker     struct TouchPoint {
78*38e8c45fSAndroid Build Coastguard Worker         Eigen::Vector2f position; // (y, x) in pixels
79*38e8c45fSAndroid Build Coastguard Worker         float pressure;
80*38e8c45fSAndroid Build Coastguard Worker     };
81*38e8c45fSAndroid Build Coastguard Worker 
82*38e8c45fSAndroid Build Coastguard Worker     struct GroundTruthPoint : TouchPoint {
83*38e8c45fSAndroid Build Coastguard Worker         nsecs_t timestamp;
84*38e8c45fSAndroid Build Coastguard Worker     };
85*38e8c45fSAndroid Build Coastguard Worker 
86*38e8c45fSAndroid Build Coastguard Worker     struct PredictionPoint : TouchPoint {
87*38e8c45fSAndroid Build Coastguard Worker         // The timestamp of the last ground truth point when the prediction was made.
88*38e8c45fSAndroid Build Coastguard Worker         nsecs_t originTimestamp;
89*38e8c45fSAndroid Build Coastguard Worker 
90*38e8c45fSAndroid Build Coastguard Worker         nsecs_t targetTimestamp;
91*38e8c45fSAndroid Build Coastguard Worker 
92*38e8c45fSAndroid Build Coastguard Worker         // Order by targetTimestamp when sorting.
93*38e8c45fSAndroid Build Coastguard Worker         bool operator<(const PredictionPoint& other) const {
94*38e8c45fSAndroid Build Coastguard Worker             return this->targetTimestamp < other.targetTimestamp;
95*38e8c45fSAndroid Build Coastguard Worker         }
96*38e8c45fSAndroid Build Coastguard Worker     };
97*38e8c45fSAndroid Build Coastguard Worker 
98*38e8c45fSAndroid Build Coastguard Worker     // Metrics aggregated so far for the current stroke. These are not the final fields to be
99*38e8c45fSAndroid Build Coastguard Worker     // reported in the atom (see AtomFields below), but rather an intermediate representation of the
100*38e8c45fSAndroid Build Coastguard Worker     // data that can be conveniently aggregated and from which the atom fields can be derived later.
101*38e8c45fSAndroid Build Coastguard Worker     //
102*38e8c45fSAndroid Build Coastguard Worker     // Displacement units are in pixels.
103*38e8c45fSAndroid Build Coastguard Worker     //
104*38e8c45fSAndroid Build Coastguard Worker     // "Along-trajectory error" is the dot product of the prediction error with the unit vector
105*38e8c45fSAndroid Build Coastguard Worker     // pointing towards the ground truth point whose timestamp corresponds to the prediction
106*38e8c45fSAndroid Build Coastguard Worker     // target timestamp, originating from the preceding ground truth point.
107*38e8c45fSAndroid Build Coastguard Worker     //
108*38e8c45fSAndroid Build Coastguard Worker     // "Off-trajectory error" is the component of the prediction error orthogonal to the
109*38e8c45fSAndroid Build Coastguard Worker     // "along-trajectory" unit vector described above.
110*38e8c45fSAndroid Build Coastguard Worker     //
111*38e8c45fSAndroid Build Coastguard Worker     // "High-velocity" errors are errors that are only accumulated when the velocity between the
112*38e8c45fSAndroid Build Coastguard Worker     // most recent two input events exceeds a certain threshold.
113*38e8c45fSAndroid Build Coastguard Worker     //
114*38e8c45fSAndroid Build Coastguard Worker     // "Scale-invariant errors" are the errors produced when the path length of the stroke is
115*38e8c45fSAndroid Build Coastguard Worker     // scaled to 1. (In other words, the error distances are normalized by the path length.)
116*38e8c45fSAndroid Build Coastguard Worker     struct AggregatedStrokeMetrics {
117*38e8c45fSAndroid Build Coastguard Worker         // General errors
118*38e8c45fSAndroid Build Coastguard Worker         float alongTrajectoryErrorSum = 0;
119*38e8c45fSAndroid Build Coastguard Worker         float alongTrajectorySumSquaredErrors = 0;
120*38e8c45fSAndroid Build Coastguard Worker         float offTrajectorySumSquaredErrors = 0;
121*38e8c45fSAndroid Build Coastguard Worker         float pressureSumSquaredErrors = 0;
122*38e8c45fSAndroid Build Coastguard Worker         size_t generalErrorsCount = 0;
123*38e8c45fSAndroid Build Coastguard Worker 
124*38e8c45fSAndroid Build Coastguard Worker         // High-velocity errors
125*38e8c45fSAndroid Build Coastguard Worker         float highVelocityAlongTrajectorySse = 0;
126*38e8c45fSAndroid Build Coastguard Worker         float highVelocityOffTrajectorySse = 0;
127*38e8c45fSAndroid Build Coastguard Worker         size_t highVelocityErrorsCount = 0;
128*38e8c45fSAndroid Build Coastguard Worker 
129*38e8c45fSAndroid Build Coastguard Worker         // Scale-invariant errors
130*38e8c45fSAndroid Build Coastguard Worker         float scaleInvariantAlongTrajectorySse = 0;
131*38e8c45fSAndroid Build Coastguard Worker         float scaleInvariantOffTrajectorySse = 0;
132*38e8c45fSAndroid Build Coastguard Worker         size_t scaleInvariantErrorsCount = 0;
133*38e8c45fSAndroid Build Coastguard Worker     };
134*38e8c45fSAndroid Build Coastguard Worker 
135*38e8c45fSAndroid Build Coastguard Worker     // In order to explicitly indicate "no relevant data" for a metric, we report this
136*38e8c45fSAndroid Build Coastguard Worker     // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
137*38e8c45fSAndroid Build Coastguard Worker     // completely unobtainable. For along-trajectory error mean, which can be negative, the
138*38e8c45fSAndroid Build Coastguard Worker     // magnitude makes it unobtainable in practice.)
139*38e8c45fSAndroid Build Coastguard Worker     static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();
140*38e8c45fSAndroid Build Coastguard Worker 
141*38e8c45fSAndroid Build Coastguard Worker     // Final metric values reported in the atom.
142*38e8c45fSAndroid Build Coastguard Worker     struct AtomFields {
143*38e8c45fSAndroid Build Coastguard Worker         int deltaTimeBucketMilliseconds = 0;
144*38e8c45fSAndroid Build Coastguard Worker 
145*38e8c45fSAndroid Build Coastguard Worker         // General errors
146*38e8c45fSAndroid Build Coastguard Worker         int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
147*38e8c45fSAndroid Build Coastguard Worker         int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
148*38e8c45fSAndroid Build Coastguard Worker         int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
149*38e8c45fSAndroid Build Coastguard Worker         int pressureRmseMilliunits = NO_DATA_SENTINEL;
150*38e8c45fSAndroid Build Coastguard Worker 
151*38e8c45fSAndroid Build Coastguard Worker         // High-velocity errors
152*38e8c45fSAndroid Build Coastguard Worker         int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
153*38e8c45fSAndroid Build Coastguard Worker         int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
154*38e8c45fSAndroid Build Coastguard Worker 
155*38e8c45fSAndroid Build Coastguard Worker         // Scale-invariant errors
156*38e8c45fSAndroid Build Coastguard Worker         int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
157*38e8c45fSAndroid Build Coastguard Worker         int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
158*38e8c45fSAndroid Build Coastguard Worker     };
159*38e8c45fSAndroid Build Coastguard Worker 
160*38e8c45fSAndroid Build Coastguard Worker private:
161*38e8c45fSAndroid Build Coastguard Worker     // The interval between consecutive predictions' target timestamps. We assume that the input
162*38e8c45fSAndroid Build Coastguard Worker     // interval also equals this value.
163*38e8c45fSAndroid Build Coastguard Worker     const nsecs_t mPredictionInterval;
164*38e8c45fSAndroid Build Coastguard Worker 
165*38e8c45fSAndroid Build Coastguard Worker     // The maximum number of input frames into the future the model can predict.
166*38e8c45fSAndroid Build Coastguard Worker     // Used to perform time-bucketing of metrics.
167*38e8c45fSAndroid Build Coastguard Worker     const size_t mMaxNumPredictions;
168*38e8c45fSAndroid Build Coastguard Worker 
169*38e8c45fSAndroid Build Coastguard Worker     // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
170*38e8c45fSAndroid Build Coastguard Worker     // error. (Also, the last two points are used to compute the ground truth trajectory.)
171*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;
172*38e8c45fSAndroid Build Coastguard Worker 
173*38e8c45fSAndroid Build Coastguard Worker     // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
174*38e8c45fSAndroid Build Coastguard Worker     // Invariant: sorted in ascending order of targetTimestamp.
175*38e8c45fSAndroid Build Coastguard Worker     std::vector<PredictionPoint> mRecentPredictions;
176*38e8c45fSAndroid Build Coastguard Worker 
177*38e8c45fSAndroid Build Coastguard Worker     // Containers for the intermediate representation of stroke metrics and the final atom fields.
178*38e8c45fSAndroid Build Coastguard Worker     // These are indexed by the number of input frames into the future being predicted minus one,
179*38e8c45fSAndroid Build Coastguard Worker     // and always have size mMaxNumPredictions.
180*38e8c45fSAndroid Build Coastguard Worker     std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
181*38e8c45fSAndroid Build Coastguard Worker     std::vector<AtomFields> mAtomFields;
182*38e8c45fSAndroid Build Coastguard Worker 
183*38e8c45fSAndroid Build Coastguard Worker     const ReportAtomFunction mReportAtomFunction;
184*38e8c45fSAndroid Build Coastguard Worker 
185*38e8c45fSAndroid Build Coastguard Worker     // Helper methods for the implementation of onRecord and onPredict.
186*38e8c45fSAndroid Build Coastguard Worker 
187*38e8c45fSAndroid Build Coastguard Worker     // Clears stored ground truth and prediction points, as well as all stored metrics for the
188*38e8c45fSAndroid Build Coastguard Worker     // current stroke.
189*38e8c45fSAndroid Build Coastguard Worker     void clearStrokeData();
190*38e8c45fSAndroid Build Coastguard Worker 
191*38e8c45fSAndroid Build Coastguard Worker     // Adds the new ground truth point to mRecentGroundTruths, removes outdated predictions from
192*38e8c45fSAndroid Build Coastguard Worker     // mRecentPredictions, and updates the aggregated metrics to include the recent predictions that
193*38e8c45fSAndroid Build Coastguard Worker     // fuzzily match with the new ground truth point.
194*38e8c45fSAndroid Build Coastguard Worker     void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);
195*38e8c45fSAndroid Build Coastguard Worker 
196*38e8c45fSAndroid Build Coastguard Worker     // Given a new prediction with targetTimestamp matching the latest ground truth point's
197*38e8c45fSAndroid Build Coastguard Worker     // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
198*38e8c45fSAndroid Build Coastguard Worker     void updateAggregatedMetrics(const PredictionPoint& predictionPoint);
199*38e8c45fSAndroid Build Coastguard Worker 
200*38e8c45fSAndroid Build Coastguard Worker     // Computes the atom fields to mAtomFields from the values in mAggregatedMetrics.
201*38e8c45fSAndroid Build Coastguard Worker     void computeAtomFields();
202*38e8c45fSAndroid Build Coastguard Worker 
203*38e8c45fSAndroid Build Coastguard Worker     // Reports the current data in mAtomFields by calling mReportAtomFunction.
204*38e8c45fSAndroid Build Coastguard Worker     void reportMetrics();
205*38e8c45fSAndroid Build Coastguard Worker };
206*38e8c45fSAndroid Build Coastguard Worker 
207*38e8c45fSAndroid Build Coastguard Worker } // namespace android
208