1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
6 #define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
7 
#if defined(__ANDROID__) || defined(__ANDROID_HOST__)
#include "chrome_to_android_compatibility.h"
#endif

#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

#include "base/component_export.h"
#include "base/optional.h"
#include "base/time/time.h"
19 
20 namespace ui {
21 
COMPONENT_EXPORT(EVDEV)22 struct COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilterModelConfig {
23   // Explicit constructor to make chromium style happy.
24   NeuralStylusPalmDetectionFilterModelConfig();
25   NeuralStylusPalmDetectionFilterModelConfig(
26       const NeuralStylusPalmDetectionFilterModelConfig& other);
27   ~NeuralStylusPalmDetectionFilterModelConfig();
28   // Number of nearest neighbors to use in vector construction.
29   uint32_t nearest_neighbor_count = 0;
30 
31   // Number of biggest nearby neighbors to use in vector construction.
32   uint32_t biggest_near_neighbor_count = 0;
33 
34   // Maximum distance of neighbor centroid, in millimeters.
35   float max_neighbor_distance_in_mm = 0.0f;
36 
37   base::TimeDelta max_dead_neighbor_time;
38 
39   // Minimum count of samples in a stroke for neural comparison.
40   uint32_t min_sample_count = 0;
41 
42   // Maximum sample count.
43   uint32_t max_sample_count = 0;
44 
45   // Convert the provided 'sample_count' to an equivalent time duration.
46   // Should only be called when resampling is enabled.
47   base::TimeDelta GetEquivalentDuration(uint32_t sample_count) const;
48 
49   // Minimum count of samples for a stroke to be considered as a neighbor.
50   uint32_t neighbor_min_sample_count = 0;
51 
52   bool include_sequence_count_in_strokes = false;
53 
54   // If this number is positive, short strokes with a touch major greater than
55   // or equal to this should be marked as a palm. If 0 or less, has no effect.
56   float heuristic_palm_touch_limit = 0.0f;
57 
58   // If this number is positive, short strokes with any touch having an area
59   // greater than or equal to this should be marked as a palm. If <= 0, has no
60   // effect
61   float heuristic_palm_area_limit = 0.0f;
62 
63   // If true, runs the heuristic palm check on short strokes, and enables delay
64   // on them if the heuristic would have marked the touch as a palm at that
65   // point.
66   bool heuristic_delay_start_if_palm = false;
67 
68   // Similar to `heuristic_delay_start_if_palm`, but uses NN model to do the
69   // early check. NN early check happens on strokes with certain sample_counts
70   // defined in `early_stage_sample_counts`.
71   bool nn_delay_start_if_palm = false;
72 
73   // Maximum blank time within a session, in milliseconds.
74   // Two tracking_ids are considered in one session if they overlap with each
75   // other or the gap between them is less than max_blank_time.
76   base::TimeDelta max_blank_time;
77 
78   // If true, uses tracking_id count within a session as a feature.
79   bool use_tracking_id_count = false;
80 
81   // If true, uses current active tracking_id count as a feature.
82   bool use_active_tracking_id_count = false;
83 
84   // The model version (e.g. "alpha" for kohaku, "beta" for redrix) to use.
85   std::string model_version;
86 
87   // If empty, the radius by the device is left as is.
88   // If non empty, the radius reported by device is re-sized in features by the
89   // polynomial defined in this vector. E.g. if this vector is {0.5, 1.3,
90   // -0.2, 1.0} Each radius r is replaced by
91   //
92   // R = 0.5 * r^3 + 1.3 * r^2 - 0.2 * r + 1
93   std::vector<float> radius_polynomial_resize;
94 
95   float output_threshold = 0.0f;
96 
97   // If a stroke has these numbers of samples, run an early stage detection to
98   // check if it's spurious and mark it held if so.
99   std::unordered_set<uint32_t> early_stage_sample_counts;
100 
101   // If set, time between values to resample. Must match the value coded into
102   // model. Currently the model is developed for 120Hz touch devices, so this
103   // value must be set to "8 ms" if your device has a different refresh rate.
104   // If not set, no resampling is done.
105   base::Optional<base::TimeDelta> resample_period;
106 };
107 
108 // An abstract model utilized by NueralStylusPalmDetectionFilter.
COMPONENT_EXPORT(EVDEV)109 class COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilterModel {
110  public:
111   virtual ~NeuralStylusPalmDetectionFilterModel() {}
112 
113   // Actually execute inference on floating point input. If the length of
114   // features is not correct, return Nan. The return value is assumed to be the
115   // input of a sigmoid. i.e. any value greater than 0 implies a positive
116   // result.
117   virtual float Inference(const std::vector<float>& features) const = 0;
118 
119   virtual const NeuralStylusPalmDetectionFilterModelConfig& config() const = 0;
120 };
121 
122 }  // namespace ui
123 
124 #endif  // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
125