/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <vector>

#include <input/Input.h> // for MotionEvent
#include <input/RingBuffer.h>
#include <utils/Timers.h> // for nsecs_t

#include "Eigen/Core"

namespace android {

/**
 * Class to handle computing and reporting metrics for MotionPredictor.
 *
 * The public API provides two methods: `onRecord` and `onPredict`, which expect to receive the
 * MotionEvents from the corresponding methods in MotionPredictor.
 *
 * This class stores AggregatedStrokeMetrics, updating them as new MotionEvents are passed in. When
 * onRecord receives an UP or CANCEL event, this indicates the end of the stroke, and the final
 * AtomFields are computed and reported to the stats library. The number of atoms reported is equal
 * to the value of `maxNumPredictions` passed to the constructor.
 * Each atom corresponds to one "prediction time bucket": the amount of time into the future being
 * predicted.
 *
 * Each AtomFields in the batch for one stroke is passed to the ReportAtomFunction given to the
 * constructor. Tests can supply their own ReportAtomFunction to capture the reported AtomFields
 * instead of writing them to the stats library.
 */
class MotionPredictorMetricsManager {
public:
    struct AtomFields;

    using ReportAtomFunction = std::function<void(const AtomFields&)>;

    static void defaultReportAtomFunction(const AtomFields& atomFields);

    // Parameters:
    //  • predictionInterval: the time interval between successive prediction target timestamps.
    //    Note: the MetricsManager assumes that the input interval equals the prediction interval.
    //  • maxNumPredictions: the maximum number of distinct target timestamps the prediction model
    //    will generate predictions for. The MetricsManager reports this many atoms per stroke.
    //  • [Optional] reportAtomFunction: the function that will be called to report metrics.
    //    If omitted (or if an empty function is given), the `stats_write(…)` function from the
    //    Android stats library will be used.
    MotionPredictorMetricsManager(
            nsecs_t predictionInterval,
            size_t maxNumPredictions,
            ReportAtomFunction reportAtomFunction = defaultReportAtomFunction);

    // This method should be called once for each call to MotionPredictor::record, receiving the
    // forwarded MotionEvent argument.
    void onRecord(const MotionEvent& inputEvent);

    // This method should be called once for each call to MotionPredictor::predict, receiving the
    // MotionEvent that will be returned by MotionPredictor::predict.
    void onPredict(const MotionEvent& predictionEvent);

    // Simple structs to hold relevant touch input information. Public so they can be used in tests.

    struct TouchPoint {
        Eigen::Vector2f position; // (y, x) in pixels
        float pressure;
    };

    struct GroundTruthPoint : TouchPoint {
        nsecs_t timestamp;
    };

    struct PredictionPoint : TouchPoint {
        // The timestamp of the last ground truth point when the prediction was made.
        nsecs_t originTimestamp;

        nsecs_t targetTimestamp;

        // Order by targetTimestamp when sorting.
        bool operator<(const PredictionPoint& other) const {
            return this->targetTimestamp < other.targetTimestamp;
        }
    };

    // Metrics aggregated so far for the current stroke. These are not the final fields to be
    // reported in the atom (see AtomFields below), but rather an intermediate representation of the
    // data that can be conveniently aggregated and from which the atom fields can be derived later.
    //
    // Displacement units are in pixels.
    //
    // "Along-trajectory error" is the dot product of the prediction error with the unit vector
    // pointing towards the ground truth point whose timestamp corresponds to the prediction
    // target timestamp, originating from the preceding ground truth point.
    //
    // "Off-trajectory error" is the component of the prediction error orthogonal to the
    // "along-trajectory" unit vector described above.
    //
    // "High-velocity" errors are errors that are only accumulated when the velocity between the
    // most recent two input events exceeds a certain threshold.
    //
    // "Scale-invariant errors" are the errors produced when the path length of the stroke is
    // scaled to 1. (In other words, the error distances are normalized by the path length.)
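    //
    // Illustrative sketch only, not necessarily how the .cpp implementation does it: given the two
    // most recent ground truth positions `prev` and `curr`, and a prediction position `pred` whose
    // target timestamp matches `curr`'s timestamp, the general errors above could be computed and
    // accumulated with Eigen as follows. The names `prev`, `curr`, `pred`, and `metrics` (an
    // AggregatedStrokeMetrics) are hypothetical.
    //
    //     const Eigen::Vector2f error = pred - curr;                // prediction error
    //     const Eigen::Vector2f along = (curr - prev).normalized(); // assumes curr != prev
    //     const float alongError = error.dot(along);
    //     // Signed length of the orthogonal component: the 2D cross product of `along` (a unit
    //     // vector) and `error`. Its sign depends on component order, but only its square is used.
    //     const float offError = along.x() * error.y() - along.y() * error.x();
    //     metrics.alongTrajectoryErrorSum += alongError;
    //     metrics.alongTrajectorySumSquaredErrors += alongError * alongError;
    //     metrics.offTrajectorySumSquaredErrors += offError * offError;
    //     metrics.generalErrorsCount++;
    //
    // The scale-invariant variants would divide alongError and offError by the path length before
    // squaring them.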
    struct AggregatedStrokeMetrics {
        // General errors
        float alongTrajectoryErrorSum = 0;
        float alongTrajectorySumSquaredErrors = 0;
        float offTrajectorySumSquaredErrors = 0;
        float pressureSumSquaredErrors = 0;
        size_t generalErrorsCount = 0;

        // High-velocity errors
        float highVelocityAlongTrajectorySse = 0;
        float highVelocityOffTrajectorySse = 0;
        size_t highVelocityErrorsCount = 0;

        // Scale-invariant errors
        float scaleInvariantAlongTrajectorySse = 0;
        float scaleInvariantOffTrajectorySse = 0;
        size_t scaleInvariantErrorsCount = 0;
    };

    // In order to explicitly indicate "no relevant data" for a metric, we report this
    // large-magnitude negative sentinel value. (Most metrics are non-negative, so this value is
    // completely unobtainable. For along-trajectory error mean, which can be negative, the
    // magnitude makes it unobtainable in practice.)
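    //
    // For scale: std::numeric_limits<int32_t>::min() is -2,147,483,648. Read as millipixels, that
    // is a mean error of about -2.15 million pixels, far larger than any display. A consumer of
    // AtomFields could therefore test each field like this (`atom` is a hypothetical AtomFields):
    //
    //     if (atom.offTrajectoryRmseMillipixels != NO_DATA_SENTINEL) {
    //         // The field holds a real RMSE value, in millipixels.
    //     }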
    static const int NO_DATA_SENTINEL = std::numeric_limits<int32_t>::min();

    // Final metric values reported in the atom.
    struct AtomFields {
        int deltaTimeBucketMilliseconds = 0;

        // General errors
        int alongTrajectoryErrorMeanMillipixels = NO_DATA_SENTINEL;
        int alongTrajectoryErrorStdMillipixels = NO_DATA_SENTINEL;
        int offTrajectoryRmseMillipixels = NO_DATA_SENTINEL;
        int pressureRmseMilliunits = NO_DATA_SENTINEL;

        // High-velocity errors
        int highVelocityAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int highVelocityOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels

        // Scale-invariant errors
        int scaleInvariantAlongTrajectoryRmse = NO_DATA_SENTINEL; // millipixels
        int scaleInvariantOffTrajectoryRmse = NO_DATA_SENTINEL;   // millipixels
    };

private:
    // The interval between consecutive predictions' target timestamps. We assume that the input
    // interval also equals this value.
    const nsecs_t mPredictionInterval;

    // The maximum number of input frames into the future the model can predict.
    // Used to perform time-bucketing of metrics.
    const size_t mMaxNumPredictions;

    // History of mMaxNumPredictions + 1 ground truth points, used to compute scale-invariant
    // error. (Also, the last two points are used to compute the ground truth trajectory.)
    RingBuffer<GroundTruthPoint> mRecentGroundTruthPoints;

    // Predictions having a targetTimestamp after the most recent ground truth point's timestamp.
    // Invariant: sorted in ascending order of targetTimestamp.
    std::vector<PredictionPoint> mRecentPredictions;

    // Containers for the intermediate representation of stroke metrics and the final atom fields.
    // These are indexed by the number of input frames into the future being predicted minus one,
    // and always have size mMaxNumPredictions.
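    //
    // Illustrative sketch, assuming (not confirmed by this header) that a prediction's bucket is
    // derived from its timestamps: because the input interval is assumed to equal
    // mPredictionInterval, the number of frames into the future is the time from the prediction's
    // origin to its target, divided by that interval. `prediction` is a hypothetical
    // PredictionPoint.
    //
    //     const nsecs_t delta = prediction.targetTimestamp - prediction.originTimestamp;
    //     // Round to the nearest whole number of intervals to tolerate timestamp jitter.
    //     const nsecs_t framesAhead = (delta + mPredictionInterval / 2) / mPredictionInterval;
    //     if (framesAhead >= 1 && static_cast<size_t>(framesAhead) <= mMaxNumPredictions) {
    //         AggregatedStrokeMetrics& bucket = mAggregatedMetrics[framesAhead - 1];
    //         // ... accumulate this prediction's errors into `bucket` ...
    //     }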
    std::vector<AggregatedStrokeMetrics> mAggregatedMetrics;
    std::vector<AtomFields> mAtomFields;

    const ReportAtomFunction mReportAtomFunction;

    // Helper methods for the implementation of onRecord and onPredict.

    // Clears stored ground truth and prediction points, as well as all stored metrics for the
    // current stroke.
    void clearStrokeData();

    // Adds the new ground truth point to mRecentGroundTruthPoints, removes outdated predictions
    // from mRecentPredictions, and updates the aggregated metrics to include the recent predictions
    // that fuzzily match with the new ground truth point.
    void incorporateNewGroundTruth(const GroundTruthPoint& groundTruthPoint);

    // Given a new prediction with targetTimestamp matching the latest ground truth point's
    // timestamp, computes the corresponding metrics and updates mAggregatedMetrics.
    void updateAggregatedMetrics(const PredictionPoint& predictionPoint);

    // Computes the atom fields from the values in mAggregatedMetrics and stores them in
    // mAtomFields.
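    //
    // A minimal sketch of that derivation, assuming the standard definitions of mean, standard
    // deviation, and RMSE; the exact rounding in the .cpp file may differ. `m` is a hypothetical
    // AggregatedStrokeMetrics and `a` the AtomFields being filled in. The count is assumed to be
    // nonzero; a bucket with no data would keep NO_DATA_SENTINEL. Needs <algorithm> and <cmath>.
    //
    //     const float n = static_cast<float>(m.generalErrorsCount);
    //     const float mean = m.alongTrajectoryErrorSum / n;
    //     // Var[X] = E[X^2] - E[X]^2, clamped at 0 to absorb floating-point rounding.
    //     const float variance = std::max(0.f, m.alongTrajectorySumSquaredErrors / n - mean * mean);
    //     a.alongTrajectoryErrorMeanMillipixels = static_cast<int>(mean * 1000);
    //     a.alongTrajectoryErrorStdMillipixels = static_cast<int>(std::sqrt(variance) * 1000);
    //     a.offTrajectoryRmseMillipixels =
    //             static_cast<int>(std::sqrt(m.offTrajectorySumSquaredErrors / n) * 1000);
    //     a.pressureRmseMilliunits =
    //             static_cast<int>(std::sqrt(m.pressureSumSquaredErrors / n) * 1000);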
    void computeAtomFields();

    // Reports the current data in mAtomFields by calling mReportAtomFunction.
    void reportMetrics();
};

} // namespace android