1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
6 
7 #include <base/logging.h>
8 #include <algorithm>
9 
10 namespace ui {
11 
12 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
CreatePalmFilterDeviceInfo(const EventDeviceInfo & devinfo)13 PalmFilterDeviceInfo CreatePalmFilterDeviceInfo(
14     const EventDeviceInfo& devinfo) {
15   PalmFilterDeviceInfo info;
16 
17   info.max_x = devinfo.GetAbsMaximum(ABS_MT_POSITION_X);
18   info.x_res = devinfo.GetAbsResolution(ABS_MT_POSITION_X);
19   info.max_y = devinfo.GetAbsMaximum(ABS_MT_POSITION_Y);
20   info.y_res = devinfo.GetAbsResolution(ABS_MT_POSITION_Y);
21   if (info.x_res == 0) {
22     info.x_res = 1;
23   }
24   if (info.y_res == 0) {
25     info.y_res = 1;
26   }
27 
28   info.major_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MAJOR);
29   if (info.major_radius_res == 0) {
30     // Device does not report major res: set to 1.
31     info.major_radius_res = 1;
32   }
33   if (devinfo.HasAbsEvent(ABS_MT_TOUCH_MINOR)) {
34     info.minor_radius_supported = true;
35     info.minor_radius_res = devinfo.GetAbsResolution(ABS_MT_TOUCH_MINOR);
36   } else {
37     info.minor_radius_supported = false;
38     info.minor_radius_res = info.major_radius_res;
39   }
40   if (info.minor_radius_res == 0) {
41     // Device does not report minor res: set to 1.
42     info.minor_radius_res = 1;
43   }
44 
45   return info;
46 }
47 #endif
48 
49 namespace {
ScaledRadius(float radius,const NeuralStylusPalmDetectionFilterModelConfig & model_config)50 float ScaledRadius(
51     float radius,
52     const NeuralStylusPalmDetectionFilterModelConfig& model_config) {
53   if (model_config.radius_polynomial_resize.empty()) {
54     return radius;
55   }
56   float return_value = 0.0f;
57   for (uint32_t i = 0; i < model_config.radius_polynomial_resize.size(); ++i) {
58     float power = model_config.radius_polynomial_resize.size() - 1 - i;
59     return_value +=
60         model_config.radius_polynomial_resize[i] * powf(radius, power);
61   }
62   return return_value;
63 }
64 
// Linear interpolation: returns |start_value| at proportion 0 and |end_value|
// at proportion 1. Proportions outside [0, 1] extrapolate.
float interpolate(float start_value, float end_value, float proportion) {
  const float delta = end_value - start_value;
  return start_value + delta * proportion;
}
68 
69 /**
70  * During resampling, the later events are used as a basis to populate
71  * non-resampled fields like major and minor. However, if the requested time is
72  * within this delay of the earlier event, the earlier event will be used as a
73  * basis instead.
74  */
75 const static auto kPreferInitialEventDelay =
76     base::TimeDelta::FromMicroseconds(1);
77 
78 /**
79  * Interpolate between the "before" and "after" events to get a resampled value
80  * at the timestamp 'time'. Not all fields are interpolated. For fields that are
81  * not interpolated, the values are taken from the 'after' sample unless the
82  * requested time is very close to the 'before' sample.
83  */
GetSampleAtTime(base::TimeTicks time,const PalmFilterSample & before,const PalmFilterSample & after)84 PalmFilterSample GetSampleAtTime(base::TimeTicks time,
85                                  const PalmFilterSample& before,
86                                  const PalmFilterSample& after) {
87   // Use the newest sample as the base, except when the requested time is very
88   // close to the 'before' sample.
89   PalmFilterSample result = after;
90   if (time - before.time < kPreferInitialEventDelay) {
91     result = before;
92   }
93   // Only the x and y values are interpolated. We could also interpolate the
94   // oval size and orientation, but it's not a simple computation, and would
95   // likely not provide much value.
96   const float proportion =
97       static_cast<float>((time - before.time).InNanoseconds()) /
98       (after.time - before.time).InNanoseconds();
99   result.edge = interpolate(before.edge, after.edge, proportion);
100   result.point.set_x(
101       interpolate(before.point.x(), after.point.x(), proportion));
102   result.point.set_y(
103       interpolate(before.point.y(), after.point.y(), proportion));
104   result.time = time;
105   return result;
106 }
107 }  // namespace
108 
CreatePalmFilterSample(const InProgressTouchEvdev & touch,const base::TimeTicks & time,const NeuralStylusPalmDetectionFilterModelConfig & model_config,const PalmFilterDeviceInfo & dev_info)109 PalmFilterSample CreatePalmFilterSample(
110     const InProgressTouchEvdev& touch,
111     const base::TimeTicks& time,
112     const NeuralStylusPalmDetectionFilterModelConfig& model_config,
113     const PalmFilterDeviceInfo& dev_info) {
114   // radius_x and radius_y have been
115   // scaled by resolution already.
116 
117   PalmFilterSample sample;
118   sample.time = time;
119 
120   sample.major_radius = ScaledRadius(
121       std::max(touch.major, touch.minor) / dev_info.major_radius_res,
122       model_config);
123   if (dev_info.minor_radius_supported) {
124     sample.minor_radius = ScaledRadius(
125         std::min(touch.major, touch.minor) / dev_info.minor_radius_res,
126         model_config);
127   } else {
128     sample.minor_radius = ScaledRadius(touch.major, model_config);
129   }
130 
131   float nearest_x_edge = std::min(touch.x, dev_info.max_x - touch.x);
132   float nearest_y_edge = std::min(touch.y, dev_info.max_y - touch.y);
133   float normalized_x_edge = nearest_x_edge / dev_info.x_res;
134   float normalized_y_edge = nearest_y_edge / dev_info.y_res;
135   // Nearest edge distance, in mm.
136   sample.edge = std::min(normalized_x_edge, normalized_y_edge);
137   sample.point =
138       gfx::PointF(touch.x / dev_info.x_res, touch.y / dev_info.y_res);
139   sample.tracking_id = touch.tracking_id;
140   sample.pressure = touch.pressure;
141 
142   return sample;
143 }
144 
PalmFilterStroke(const NeuralStylusPalmDetectionFilterModelConfig & model_config,int tracking_id)145 PalmFilterStroke::PalmFilterStroke(
146     const NeuralStylusPalmDetectionFilterModelConfig& model_config,
147     int tracking_id)
148     : tracking_id_(tracking_id),
149       max_sample_count_(model_config.max_sample_count),
150       resample_period_(model_config.resample_period) {}
151 PalmFilterStroke::PalmFilterStroke(const PalmFilterStroke& other) = default;
152 PalmFilterStroke::PalmFilterStroke(PalmFilterStroke&& other) = default;
~PalmFilterStroke()153 PalmFilterStroke::~PalmFilterStroke() {}
154 
ProcessSample(const PalmFilterSample & sample)155 void PalmFilterStroke::ProcessSample(const PalmFilterSample& sample) {
156   DCHECK_EQ(tracking_id_, sample.tracking_id);
157   if (samples_seen_ == 0) {
158     first_sample_time_ = sample.time;
159   }
160 
161   AddSample(sample);
162 
163   if (resample_period_.has_value()) {
164     // Prune based on time
165     const base::TimeDelta max_duration =
166         (*resample_period_) * (max_sample_count_ - 1);
167     while (samples_.size() > 2 &&
168            samples_.back().time - samples_[1].time >= max_duration) {
169       // We can only discard the sample if after it's discarded, we still cover
170       // the entire range. If we don't, we need to keep this sample for
171       // calculating resampled values.
172       AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
173       samples_.pop_front();
174     }
175   } else {
176     // Prune based on number of samples
177     while (samples_.size() > max_sample_count_) {
178       AddToUnscaledCentroid(-samples_.front().point.OffsetFromOrigin());
179       samples_.pop_front();
180     }
181   }
182 }
183 
AddSample(const PalmFilterSample & sample)184 void PalmFilterStroke::AddSample(const PalmFilterSample& sample) {
185   AddToUnscaledCentroid(sample.point.OffsetFromOrigin());
186   samples_.push_back(sample);
187   samples_seen_++;
188 }
189 
AddToUnscaledCentroid(const gfx::Vector2dF point)190 void PalmFilterStroke::AddToUnscaledCentroid(const gfx::Vector2dF point) {
191   const gfx::Vector2dF corrected_point = point - unscaled_centroid_sum_error_;
192   const gfx::PointF new_unscaled_centroid =
193       unscaled_centroid_ + corrected_point;
194   unscaled_centroid_sum_error_ =
195       (new_unscaled_centroid - unscaled_centroid_) - corrected_point;
196   unscaled_centroid_ = new_unscaled_centroid;
197 }
198 
GetCentroid() const199 gfx::PointF PalmFilterStroke::GetCentroid() const {
200   if (samples_.size() == 0) {
201     return gfx::PointF(0., 0.);
202   }
203   return gfx::ScalePoint(unscaled_centroid_, 1.f / samples_.size());
204 }
205 
samples() const206 const std::deque<PalmFilterSample>& PalmFilterStroke::samples() const {
207   return samples_;
208 }
209 
tracking_id() const210 int PalmFilterStroke::tracking_id() const {
211   return tracking_id_;
212 }
213 
Duration() const214 base::TimeDelta PalmFilterStroke::Duration() const {
215   if (samples_.empty()) {
216     LOG(DFATAL) << "No samples available";
217     return base::Milliseconds(0);
218   }
219   return samples_.back().time - first_sample_time_;
220 }
221 
PreviousDuration() const222 base::TimeDelta PalmFilterStroke::PreviousDuration() const {
223   if (samples_.size() <= 1) {
224     LOG(DFATAL) << "Not enough samples";
225     return base::Milliseconds(0);
226   }
227   const PalmFilterSample& secondToLastSample = samples_.rbegin()[1];
228   return secondToLastSample.time - first_sample_time_;
229 }
230 
LastSampleCrossed(base::TimeDelta duration) const231 bool PalmFilterStroke::LastSampleCrossed(base::TimeDelta duration) const {
232   if (samples_.size() <= 1) {
233     // If there's only 1 sample, stroke just started and Duration() is zero.
234     return false;
235   }
236   return PreviousDuration() < duration && duration <= Duration();
237 }
238 
GetSampleAt(base::TimeTicks time) const239 PalmFilterSample PalmFilterStroke::GetSampleAt(base::TimeTicks time) const {
240   size_t i = 0;
241   for (; i < samples_.size() && samples_[i].time < time; ++i) {
242   }
243 
244   if (i < samples_.size() && !samples_.empty() && samples_[i].time == time) {
245     return samples_[i];
246   }
247   if (i == 0 || i == samples_.size()) {
248     LOG(DFATAL) << "Invalid index: " << i
249                 << ", can't interpolate for time: " << time;
250     return {};
251   }
252   return GetSampleAtTime(time, samples_[i - 1], samples_[i]);
253 }
254 
samples_seen() const255 uint64_t PalmFilterStroke::samples_seen() const {
256   return samples_seen_;
257 }
258 
MaxMajorRadius() const259 float PalmFilterStroke::MaxMajorRadius() const {
260   float maximum = 0.0;
261   for (const auto& sample : samples_) {
262     maximum = std::max(maximum, sample.major_radius);
263   }
264   return maximum;
265 }
266 
BiggestSize() const267 float PalmFilterStroke::BiggestSize() const {
268   float biggest = 0;
269   for (const auto& sample : samples_) {
270     float size;
271     if (sample.minor_radius <= 0) {
272       size = sample.major_radius * sample.major_radius;
273     } else {
274       size = sample.major_radius * sample.minor_radius;
275     }
276     biggest = std::max(biggest, size);
277   }
278   return biggest;
279 }
280 
// Returns |str| with |prefix| inserted at the start of every line. A trailing
// newline does not start a new (prefixed) empty line.
static std::string addLinePrefix(std::string str, const std::string& prefix) {
  std::string prefixed;
  bool at_line_start = true;
  for (const char c : str) {
    if (at_line_start) {
      prefixed += prefix;
    }
    prefixed += c;
    at_line_start = (c == '\n');
  }
  return prefixed;
}
296 
operator <<(std::ostream & out,const gfx::PointF & point)297 std::ostream& operator<<(std::ostream& out, const gfx::PointF& point) {
298   out << "PointF(" << point.x() << ", " << point.y() << ")";
299   return out;
300 }
301 
operator <<(std::ostream & out,const gfx::Vector2dF & vec)302 std::ostream& operator<<(std::ostream& out, const gfx::Vector2dF& vec) {
303   out << "Vector2dF(" << vec.x() << ", " << vec.y() << ")";
304   return out;
305 }
306 
operator <<(std::ostream & out,const PalmFilterDeviceInfo & info)307 std::ostream& operator<<(std::ostream& out, const PalmFilterDeviceInfo& info) {
308   out << "PalmFilterDeviceInfo(max_x=" << info.max_x;
309   out << ", max_y=" << info.max_y;
310   out << ", x_res=" << info.x_res;
311   out << ", y_res=" << info.y_res;
312   out << ", major_radius_res=" << info.major_radius_res;
313   out << ", minor_radius_res=" << info.minor_radius_res;
314   out << ", minor_radius_supported=" << info.minor_radius_supported;
315   out << ")";
316   return out;
317 }
318 
operator <<(std::ostream & out,const PalmFilterSample & sample)319 std::ostream& operator<<(std::ostream& out, const PalmFilterSample& sample) {
320   out << "PalmFilterSample(major=" << sample.major_radius
321       << ", minor=" << sample.minor_radius << ", pressure=" << sample.pressure
322       << ", edge=" << sample.edge << ", tracking_id=" << sample.tracking_id
323       << ", point=" << sample.point << ", time=" << sample.time << ")";
324   return out;
325 }
326 
operator <<(std::ostream & out,const PalmFilterStroke & stroke)327 std::ostream& operator<<(std::ostream& out, const PalmFilterStroke& stroke) {
328   out << "PalmFilterStroke(\n";
329   out << "  GetCentroid() = " << stroke.GetCentroid() << "\n";
330   out << "  BiggestSize() = " << stroke.BiggestSize() << "\n";
331   out << "  MaxMajorRadius() = " << stroke.MaxMajorRadius() << "\n";
332   std::stringstream stream;
333   stream << stroke.samples();
334   out << "  samples (" << stroke.samples().size() << " total): \n"
335       << addLinePrefix(stream.str(), "    ") << "\n";
336   out << "  samples_seen() = " << stroke.samples_seen() << "\n";
337   out << "  tracking_id() = " << stroke.tracking_id() << "\n";
338   out << "  max_sample_count_ = " << stroke.max_sample_count_ << "\n";
339   if (stroke.resample_period_) {
340     out << "  resample_period_ = " << *(stroke.resample_period_) << "\n";
341   } else {
342     out << "  resample_period_  = <not set>\n";
343   }
344   out << "  first_sample_time_ = " << stroke.first_sample_time_ << "\n";
345   out << "  unscaled_centroid_ = " << stroke.unscaled_centroid_ << "\n";
346   out << "  unscaled_centroid_sum_error_ = "
347       << stroke.unscaled_centroid_sum_error_ << "\n";
348   out << ")\n";
349   return out;
350 }
351 
352 }  // namespace ui
353