1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
6
7 #include <base/logging.h>
8 #include <algorithm>
9
10 namespace ui {
11
12 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
CreatePalmFilterDeviceInfo(const EventDeviceInfo & devinfo)13 PalmFilterDeviceInfo CreatePalmFilterDeviceInfo(
14 const EventDeviceInfo& devinfo) {
15 PalmFilterDeviceInfo info;
16
17 info.max_x = devinfo.GetAbsMaximum(ABS_MT_POSITION_X);
18 info.x_res = devinfo.GetAbsResolution(ABS_MT_POSITION_X);
19 info.max_y = devinfo.GetAbsMaximum(ABS_MT_POSITION_Y);
20 info.y_res = devinfo.GetAbsResolution(ABS_MT_POSITION_Y);
21 if (info.x_res == 0) {
22 info.x_res = 1;
23 }
24 if (info.y_res == 0) {
25 info.y_res = 1;
26 }
27
28 info.major_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MAJOR);
29 if (info.major_radius_res == 0) {
30 // Device does not report major res: set to 1.
31 info.major_radius_res = 1;
32 }
33 if (devinfo.HasAbsEvent(ABS_MT_TOUCH_MINOR)) {
34 info.minor_radius_supported = true;
35 info.minor_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MINOR);
36 } else {
37 info.minor_radius_supported = false;
38 info.minor_radius_res = info.major_radius_res;
39 }
40 if (info.minor_radius_res == 0) {
41 // Device does not report minor res: set to 1.
42 info.minor_radius_res = 1;
43 }
44
45 return info;
46 }
47 #endif
48
49 namespace {
ScaledRadius(float radius,const NeuralStylusPalmDetectionFilterModelConfig & model_config)50 float ScaledRadius(
51 float radius,
52 const NeuralStylusPalmDetectionFilterModelConfig& model_config) {
53 if (model_config.radius_polynomial_resize.empty()) {
54 return radius;
55 }
56 float return_value = 0.0f;
57 for (uint32_t i = 0; i < model_config.radius_polynomial_resize.size(); ++i) {
58 float power = model_config.radius_polynomial_resize.size() - 1 - i;
59 return_value +=
60 model_config.radius_polynomial_resize[i] * powf(radius, power);
61 }
62 return return_value;
63 }
64
// Linear interpolation between |start_value| (proportion 0) and |end_value|
// (proportion 1). Proportions outside [0, 1] extrapolate along the same line.
float interpolate(float start_value, float end_value, float proportion) {
  return start_value + proportion * (end_value - start_value);
}
68
69 /**
70 * During resampling, the later events are used as a basis to populate
71 * non-resampled fields like major and minor. However, if the requested time is
72 * within this delay of the earlier event, the earlier event will be used as a
73 * basis instead.
74 */
75 const static auto kPreferInitialEventDelay =
76 base::TimeDelta::FromMicroseconds(1);
77
78 /**
79 * Interpolate between the "before" and "after" events to get a resampled value
80 * at the timestamp 'time'. Not all fields are interpolated. For fields that are
81 * not interpolated, the values are taken from the 'after' sample unless the
82 * requested time is very close to the 'before' sample.
83 */
GetSampleAtTime(base::TimeTicks time,const PalmFilterSample & before,const PalmFilterSample & after)84 PalmFilterSample GetSampleAtTime(base::TimeTicks time,
85 const PalmFilterSample& before,
86 const PalmFilterSample& after) {
87 // Use the newest sample as the base, except when the requested time is very
88 // close to the 'before' sample.
89 PalmFilterSample result = after;
90 if (time - before.time < kPreferInitialEventDelay) {
91 result = before;
92 }
93 // Only the x and y values are interpolated. We could also interpolate the
94 // oval size and orientation, but it's not a simple computation, and would
95 // likely not provide much value.
96 const float proportion =
97 static_cast<float>((time - before.time).InNanoseconds()) /
98 (after.time - before.time).InNanoseconds();
99 result.edge = interpolate(before.edge, after.edge, proportion);
100 result.point.set_x(
101 interpolate(before.point.x(), after.point.x(), proportion));
102 result.point.set_y(
103 interpolate(before.point.y(), after.point.y(), proportion));
104 result.time = time;
105 return result;
106 }
107 } // namespace
108
CreatePalmFilterSample(const InProgressTouchEvdev & touch,const base::TimeTicks & time,const NeuralStylusPalmDetectionFilterModelConfig & model_config,const PalmFilterDeviceInfo & dev_info)109 PalmFilterSample CreatePalmFilterSample(
110 const InProgressTouchEvdev& touch,
111 const base::TimeTicks& time,
112 const NeuralStylusPalmDetectionFilterModelConfig& model_config,
113 const PalmFilterDeviceInfo& dev_info) {
114 // radius_x and radius_y have been
115 // scaled by resolution already.
116
117 PalmFilterSample sample;
118 sample.time = time;
119
120 sample.major_radius = ScaledRadius(
121 std::max(touch.major, touch.minor) / dev_info.major_radius_res,
122 model_config);
123 if (dev_info.minor_radius_supported) {
124 sample.minor_radius = ScaledRadius(
125 std::min(touch.major, touch.minor) / dev_info.minor_radius_res,
126 model_config);
127 } else {
128 sample.minor_radius = ScaledRadius(touch.major, model_config);
129 }
130
131 float nearest_x_edge = std::min(touch.x, dev_info.max_x - touch.x);
132 float nearest_y_edge = std::min(touch.y, dev_info.max_y - touch.y);
133 float normalized_x_edge = nearest_x_edge / dev_info.x_res;
134 float normalized_y_edge = nearest_y_edge / dev_info.y_res;
135 // Nearest edge distance, in mm.
136 sample.edge = std::min(normalized_x_edge, normalized_y_edge);
137 sample.point =
138 gfx::PointF(touch.x / dev_info.x_res, touch.y / dev_info.y_res);
139 sample.tracking_id = touch.tracking_id;
140 sample.pressure = touch.pressure;
141
142 return sample;
143 }
144
PalmFilterStroke(const NeuralStylusPalmDetectionFilterModelConfig & model_config,int tracking_id)145 PalmFilterStroke::PalmFilterStroke(
146 const NeuralStylusPalmDetectionFilterModelConfig& model_config,
147 int tracking_id)
148 : tracking_id_(tracking_id),
149 max_sample_count_(model_config.max_sample_count),
150 resample_period_(model_config.resample_period) {}
151 PalmFilterStroke::PalmFilterStroke(const PalmFilterStroke& other) = default;
152 PalmFilterStroke::PalmFilterStroke(PalmFilterStroke&& other) = default;
~PalmFilterStroke()153 PalmFilterStroke::~PalmFilterStroke() {}
154
ProcessSample(const PalmFilterSample & sample)155 void PalmFilterStroke::ProcessSample(const PalmFilterSample& sample) {
156 DCHECK_EQ(tracking_id_, sample.tracking_id);
157 if (samples_seen_ == 0) {
158 first_sample_time_ = sample.time;
159 }
160
161 AddSample(sample);
162
163 if (resample_period_.has_value()) {
164 // Prune based on time
165 const base::TimeDelta max_duration =
166 (*resample_period_) * (max_sample_count_ - 1);
167 while (samples_.size() > 2 &&
168 samples_.back().time - samples_[1].time >= max_duration) {
169 // We can only discard the sample if after it's discarded, we still cover
170 // the entire range. If we don't, we need to keep this sample for
171 // calculating resampled values.
172 AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
173 samples_.pop_front();
174 }
175 } else {
176 // Prune based on number of samples
177 while (samples_.size() > max_sample_count_) {
178 AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
179 samples_.pop_front();
180 }
181 }
182 }
183
AddSample(const PalmFilterSample & sample)184 void PalmFilterStroke::AddSample(const PalmFilterSample& sample) {
185 AddToUnscaledCentroid(sample.point.OffsetFromOrigin());
186 samples_.push_back(sample);
187 samples_seen_++;
188 }
189
// Adds |point| to |unscaled_centroid_|, the running sum of retained sample
// positions, using Kahan compensated summation.
// - |unscaled_centroid_sum_error_| holds the rounding error from the previous
//   addition and is subtracted from the next input.
// - Callers pass a negated offset to remove a sample from the sum.
// The exact order of operations below is what makes the compensation work; do
// not reorder or algebraically simplify it.
void PalmFilterStroke::AddToUnscaledCentroid(const gfx::Vector2dF point) {
  // Correct the incoming value by the error carried over from the last step.
  const gfx::Vector2dF corrected_point = point - unscaled_centroid_sum_error_;
  const gfx::PointF new_unscaled_centroid =
      unscaled_centroid_ + corrected_point;
  // (new - old) is the amount actually added; subtracting the intended amount
  // leaves the rounding error to compensate for on the next call.
  unscaled_centroid_sum_error_ =
      (new_unscaled_centroid - unscaled_centroid_) - corrected_point;
  unscaled_centroid_ = new_unscaled_centroid;
}
198
GetCentroid() const199 gfx::PointF PalmFilterStroke::GetCentroid() const {
200 if (samples_.size() == 0) {
201 return gfx::PointF(0., 0.);
202 }
203 return gfx::ScalePoint(unscaled_centroid_, 1.f / samples_.size());
204 }
205
samples() const206 const std::deque<PalmFilterSample>& PalmFilterStroke::samples() const {
207 return samples_;
208 }
209
tracking_id() const210 int PalmFilterStroke::tracking_id() const {
211 return tracking_id_;
212 }
213
Duration() const214 base::TimeDelta PalmFilterStroke::Duration() const {
215 if (samples_.empty()) {
216 LOG(DFATAL) << "No samples available";
217 return base::Milliseconds(0);
218 }
219 return samples_.back().time - first_sample_time_;
220 }
221
PreviousDuration() const222 base::TimeDelta PalmFilterStroke::PreviousDuration() const {
223 if (samples_.size() <= 1) {
224 LOG(DFATAL) << "Not enough samples";
225 return base::Milliseconds(0);
226 }
227 const PalmFilterSample& secondToLastSample = samples_.rbegin()[1];
228 return secondToLastSample.time - first_sample_time_;
229 }
230
LastSampleCrossed(base::TimeDelta duration) const231 bool PalmFilterStroke::LastSampleCrossed(base::TimeDelta duration) const {
232 if (samples_.size() <= 1) {
233 // If there's only 1 sample, stroke just started and Duration() is zero.
234 return false;
235 }
236 return PreviousDuration() < duration && duration <= Duration();
237 }
238
GetSampleAt(base::TimeTicks time) const239 PalmFilterSample PalmFilterStroke::GetSampleAt(base::TimeTicks time) const {
240 size_t i = 0;
241 for (; i < samples_.size() && samples_[i].time < time; ++i) {
242 }
243
244 if (i < samples_.size() && !samples_.empty() && samples_[i].time == time) {
245 return samples_[i];
246 }
247 if (i == 0 || i == samples_.size()) {
248 LOG(DFATAL) << "Invalid index: " << i
249 << ", can't interpolate for time: " << time;
250 return {};
251 }
252 return GetSampleAtTime(time, samples_[i - 1], samples_[i]);
253 }
254
samples_seen() const255 uint64_t PalmFilterStroke::samples_seen() const {
256 return samples_seen_;
257 }
258
MaxMajorRadius() const259 float PalmFilterStroke::MaxMajorRadius() const {
260 float maximum = 0.0;
261 for (const auto& sample : samples_) {
262 maximum = std::max(maximum, sample.major_radius);
263 }
264 return maximum;
265 }
266
BiggestSize() const267 float PalmFilterStroke::BiggestSize() const {
268 float biggest = 0;
269 for (const auto& sample : samples_) {
270 float size;
271 if (sample.minor_radius <= 0) {
272 size = sample.major_radius * sample.major_radius;
273 } else {
274 size = sample.major_radius * sample.minor_radius;
275 }
276 biggest = std::max(biggest, size);
277 }
278 return biggest;
279 }
280
// Returns |str| with |prefix| inserted at the start of every line. A line
// starts at the beginning of the string and after each '\n' that is followed
// by more characters. A trailing newline therefore does not get a dangling
// prefix.
static std::string addLinePrefix(std::string str, const std::string& prefix) {
  std::string prefixed;
  bool at_line_start = true;
  for (const char c : str) {
    if (at_line_start) {
      prefixed += prefix;
    }
    prefixed += c;
    at_line_start = (c == '\n');
  }
  return prefixed;
}
296
operator <<(std::ostream & out,const gfx::PointF & point)297 std::ostream& operator<<(std::ostream& out, const gfx::PointF& point) {
298 out << "PointF(" << point.x() << ", " << point.y() << ")";
299 return out;
300 }
301
operator <<(std::ostream & out,const gfx::Vector2dF & vec)302 std::ostream& operator<<(std::ostream& out, const gfx::Vector2dF& vec) {
303 out << "Vector2dF(" << vec.x() << ", " << vec.y() << ")";
304 return out;
305 }
306
operator <<(std::ostream & out,const PalmFilterDeviceInfo & info)307 std::ostream& operator<<(std::ostream& out, const PalmFilterDeviceInfo& info) {
308 out << "PalmFilterDeviceInfo(max_x=" << info.max_x;
309 out << ", max_y=" << info.max_y;
310 out << ", x_res=" << info.x_res;
311 out << ", y_res=" << info.y_res;
312 out << ", major_radius_res=" << info.major_radius_res;
313 out << ", minor_radius_res=" << info.minor_radius_res;
314 out << ", minor_radius_supported=" << info.minor_radius_supported;
315 out << ")";
316 return out;
317 }
318
operator <<(std::ostream & out,const PalmFilterSample & sample)319 std::ostream& operator<<(std::ostream& out, const PalmFilterSample& sample) {
320 out << "PalmFilterSample(major=" << sample.major_radius
321 << ", minor=" << sample.minor_radius << ", pressure=" << sample.pressure
322 << ", edge=" << sample.edge << ", tracking_id=" << sample.tracking_id
323 << ", point=" << sample.point << ", time=" << sample.time << ")";
324 return out;
325 }
326
operator <<(std::ostream & out,const PalmFilterStroke & stroke)327 std::ostream& operator<<(std::ostream& out, const PalmFilterStroke& stroke) {
328 out << "PalmFilterStroke(\n";
329 out << " GetCentroid() = " << stroke.GetCentroid() << "\n";
330 out << " BiggestSize() = " << stroke.BiggestSize() << "\n";
331 out << " MaxMajorRadius() = " << stroke.MaxMajorRadius() << "\n";
332 std::stringstream stream;
333 stream << stroke.samples();
334 out << " samples (" << stroke.samples().size() << " total): \n"
335 << addLinePrefix(stream.str(), " ") << "\n";
336 out << " samples_seen() = " << stroke.samples_seen() << "\n";
337 out << " tracking_id() = " << stroke.tracking_id() << "\n";
338 out << " max_sample_count_ = " << stroke.max_sample_count_ << "\n";
339 if (stroke.resample_period_) {
340 out << " resample_period_ = " << *(stroke.resample_period_) << "\n";
341 } else {
342 out << " resample_period_ = <not set>\n";
343 }
344 out << " first_sample_time_ = " << stroke.first_sample_time_ << "\n";
345 out << " unscaled_centroid_ = " << stroke.unscaled_centroid_ << "\n";
346 out << " unscaled_centroid_sum_error_ = "
347 << stroke.unscaled_centroid_sum_error_ << "\n";
348 out << ")\n";
349 return out;
350 }
351
352 } // namespace ui
353