1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
6 #define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
7
8 #if defined(__ANDROID__) || defined(__ANDROID_HOST__)
9 #include "chrome_to_android_compatibility.h"
10 #endif
11
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

#include "base/component_export.h"
#include "base/optional.h"
#include "base/time/time.h"
19
20 namespace ui {
21
COMPONENT_EXPORT(EVDEV)22 struct COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilterModelConfig {
23 // Explicit constructor to make chromium style happy.
24 NeuralStylusPalmDetectionFilterModelConfig();
25 NeuralStylusPalmDetectionFilterModelConfig(
26 const NeuralStylusPalmDetectionFilterModelConfig& other);
27 ~NeuralStylusPalmDetectionFilterModelConfig();
28 // Number of nearest neighbors to use in vector construction.
29 uint32_t nearest_neighbor_count = 0;
30
31 // Number of biggest nearby neighbors to use in vector construction.
32 uint32_t biggest_near_neighbor_count = 0;
33
34 // Maximum distance of neighbor centroid, in millimeters.
35 float max_neighbor_distance_in_mm = 0.0f;
36
37 base::TimeDelta max_dead_neighbor_time;
38
39 // Minimum count of samples in a stroke for neural comparison.
40 uint32_t min_sample_count = 0;
41
42 // Maximum sample count.
43 uint32_t max_sample_count = 0;
44
45 // Convert the provided 'sample_count' to an equivalent time duration.
46 // Should only be called when resampling is enabled.
47 base::TimeDelta GetEquivalentDuration(uint32_t sample_count) const;
48
49 // Minimum count of samples for a stroke to be considered as a neighbor.
50 uint32_t neighbor_min_sample_count = 0;
51
52 bool include_sequence_count_in_strokes = false;
53
54 // If this number is positive, short strokes with a touch major greater than
55 // or equal to this should be marked as a palm. If 0 or less, has no effect.
56 float heuristic_palm_touch_limit = 0.0f;
57
58 // If this number is positive, short strokes with any touch having an area
59 // greater than or equal to this should be marked as a palm. If <= 0, has no
60 // effect
61 float heuristic_palm_area_limit = 0.0f;
62
63 // If true, runs the heuristic palm check on short strokes, and enables delay
64 // on them if the heuristic would have marked the touch as a palm at that
65 // point.
66 bool heuristic_delay_start_if_palm = false;
67
68 // Similar to `heuristic_delay_start_if_palm`, but uses NN model to do the
69 // early check. NN early check happens on strokes with certain sample_counts
70 // defined in `early_stage_sample_counts`.
71 bool nn_delay_start_if_palm = false;
72
73 // Maximum blank time within a session, in milliseconds.
74 // Two tracking_ids are considered in one session if they overlap with each
75 // other or the gap between them is less than max_blank_time.
76 base::TimeDelta max_blank_time;
77
78 // If true, uses tracking_id count within a session as a feature.
79 bool use_tracking_id_count = false;
80
81 // If true, uses current active tracking_id count as a feature.
82 bool use_active_tracking_id_count = false;
83
84 // The model version (e.g. "alpha" for kohaku, "beta" for redrix) to use.
85 std::string model_version;
86
87 // If empty, the radius by the device is left as is.
88 // If non empty, the radius reported by device is re-sized in features by the
89 // polynomial defined in this vector. E.g. if this vector is {0.5, 1.3,
90 // -0.2, 1.0} Each radius r is replaced by
91 //
92 // R = 0.5 * r^3 + 1.3 * r^2 - 0.2 * r + 1
93 std::vector<float> radius_polynomial_resize;
94
95 float output_threshold = 0.0f;
96
97 // If a stroke has these numbers of samples, run an early stage detection to
98 // check if it's spurious and mark it held if so.
99 std::unordered_set<uint32_t> early_stage_sample_counts;
100
101 // If set, time between values to resample. Must match the value coded into
102 // model. Currently the model is developed for 120Hz touch devices, so this
103 // value must be set to "8 ms" if your device has a different refresh rate.
104 // If not set, no resampling is done.
105 base::Optional<base::TimeDelta> resample_period;
106 };
107
108 // An abstract model utilized by NueralStylusPalmDetectionFilter.
COMPONENT_EXPORT(EVDEV)109 class COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilterModel {
110 public:
111 virtual ~NeuralStylusPalmDetectionFilterModel() {}
112
113 // Actually execute inference on floating point input. If the length of
114 // features is not correct, return Nan. The return value is assumed to be the
115 // input of a sigmoid. i.e. any value greater than 0 implies a positive
116 // result.
117 virtual float Inference(const std::vector<float>& features) const = 0;
118
119 virtual const NeuralStylusPalmDetectionFilterModelConfig& config() const = 0;
120 };
121
122 } // namespace ui
123
124 #endif // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_MODEL_H_
125