1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
6
7 #include <algorithm>
8 #include <functional>
9 #include <memory>
10 #include <queue>
11 #include <string>
12 #include <tuple>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16
17 #include "base/command_line.h"
18 #include "base/json/json_reader.h"
19 #include "base/logging.h"
20 #include "base/values.h"
21 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
22 #include "ui/events/ozone/evdev/event_device_info.h"
23 #else
24 #include <linux/input-event-codes.h>
25 #endif
26 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
27 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
28 #include "ui/events/ozone/features.h"
29
30 namespace ui {
31 namespace {
32 // Returns the Euclidean distance between two points.
EuclideanDistance(const gfx::PointF & a,const gfx::PointF & b)33 float EuclideanDistance(const gfx::PointF& a, const gfx::PointF& b) {
34 return (a - b).Length();
35 }
36
IsEarlyStageSample(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)37 bool IsEarlyStageSample(
38 const PalmFilterStroke& stroke,
39 const NeuralStylusPalmDetectionFilterModelConfig& config) {
40 if (!config.resample_period) {
41 return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
42 config.early_stage_sample_counts.end();
43 }
44 // Duration is not well-defined for sample_count <= 1, so we handle
45 // it separately.
46 if (stroke.samples().empty()) {
47 return false;
48 }
49 if (stroke.samples().size() == 1) {
50 return config.early_stage_sample_counts.find(1) !=
51 config.early_stage_sample_counts.end();
52 }
53 for (const uint32_t sample_count : config.early_stage_sample_counts) {
54 const base::TimeDelta duration = config.GetEquivalentDuration(sample_count);
55 // Previous sample must not have passed the 'duration' threshold, but the
56 // current sample must pass the threshold
57 if (stroke.LastSampleCrossed(duration)) {
58 return true;
59 }
60 }
61 return false;
62 }
63
HasDecidedStroke(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)64 bool HasDecidedStroke(
65 const PalmFilterStroke& stroke,
66 const NeuralStylusPalmDetectionFilterModelConfig& config) {
67 if (!config.resample_period) {
68 return stroke.samples_seen() >= config.max_sample_count;
69 }
70 const base::TimeDelta max_duration =
71 config.GetEquivalentDuration(config.max_sample_count);
72 return stroke.Duration() >= max_duration;
73 }
74
IsVeryShortStroke(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)75 bool IsVeryShortStroke(
76 const PalmFilterStroke& stroke,
77 const NeuralStylusPalmDetectionFilterModelConfig& config) {
78 if (!config.resample_period) {
79 return stroke.samples_seen() < config.min_sample_count;
80 }
81 return stroke.Duration() <
82 config.GetEquivalentDuration(config.min_sample_count);
83 }
84
85 /**
86 * The provided stroke must be a neighbor stroke rather than a stroke currently
87 * being evaluated. The parameter 'neighbor_min_sample_count' might be different
88 * from the config, depending on the specific usage in the caller.
89 */
HasInsufficientDataAsNeighbor(const PalmFilterStroke & neighbor_stroke,size_t neighbor_min_sample_count,const NeuralStylusPalmDetectionFilterModelConfig & config)90 bool HasInsufficientDataAsNeighbor(
91 const PalmFilterStroke& neighbor_stroke,
92 size_t neighbor_min_sample_count,
93 const NeuralStylusPalmDetectionFilterModelConfig& config) {
94 if (!config.resample_period) {
95 return neighbor_stroke.samples().size() < neighbor_min_sample_count;
96 }
97 return neighbor_stroke.Duration() <
98 config.GetEquivalentDuration(neighbor_min_sample_count);
99 }
100
101 } // namespace
102
// Constructs the filter around |palm_model|. The model's config() supplies the
// thresholds used throughout this file. |shared_palm_state| is forwarded to
// the PalmDetectionFilter base class.
//
// Device geometry: on non-Android builds it is derived from |devinfo| via
// CreatePalmFilterDeviceInfo(); on Android builds the caller passes a
// ready-made PalmFilterDeviceInfo.
NeuralStylusPalmDetectionFilter::NeuralStylusPalmDetectionFilter(
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
    const EventDeviceInfo& devinfo,
#else
    PalmFilterDeviceInfo palm_filter_device_info,
#endif
    std::unique_ptr<NeuralStylusPalmDetectionFilterModel> palm_model,
    SharedPalmDetectionFilterState* shared_palm_state)
    : PalmDetectionFilter(shared_palm_state),
      tracking_ids_count_within_session_(0),
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
      palm_filter_dev_info_(CreatePalmFilterDeviceInfo(devinfo)),
#else
      palm_filter_dev_info_(palm_filter_device_info),
#endif
      model_(std::move(palm_model)) {
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
  // Debug-only check. Callers are expected to have run
  // CompatibleWithNeuralStylusPalmDetectionFilter() before constructing.
  DCHECK(CompatibleWithNeuralStylusPalmDetectionFilter(devinfo))
      << "One should run compatible check before instantiation.";
#endif
}
124
~NeuralStylusPalmDetectionFilter()125 NeuralStylusPalmDetectionFilter::~NeuralStylusPalmDetectionFilter() {}
126
FindBiggestNeighborsWithin(int neighbor_count,unsigned long neighbor_min_sample_count,float max_distance,const PalmFilterStroke & stroke,std::vector<std::pair<float,int>> * biggest_strokes) const127 void NeuralStylusPalmDetectionFilter::FindBiggestNeighborsWithin(
128 int neighbor_count,
129 unsigned long neighbor_min_sample_count,
130 float max_distance,
131 const PalmFilterStroke& stroke,
132 std::vector<std::pair<float, int>>* biggest_strokes) const {
133 if (neighbor_count <= 0) {
134 return;
135 }
136 // Tuple of {size, distance, stroke_id.}
137 std::priority_queue<std::tuple<float, float, int>> biggest_strokes_queue;
138 for (const auto& lookup : strokes_) {
139 const PalmFilterStroke& neighbor = lookup.second;
140 if (neighbor.tracking_id() == stroke.tracking_id()) {
141 continue;
142 }
143 if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
144 model_->config())) {
145 continue;
146 }
147 float distance =
148 EuclideanDistance(neighbor.GetCentroid(), stroke.GetCentroid());
149 if (distance > max_distance) {
150 continue;
151 }
152 float size = neighbor.BiggestSize();
153 biggest_strokes_queue.push(
154 std::make_tuple(size, distance, neighbor.tracking_id()));
155 }
156 for (int i = 0; i < neighbor_count && !biggest_strokes_queue.empty(); ++i) {
157 const auto big_stroke = biggest_strokes_queue.top();
158 biggest_strokes_queue.pop();
159 biggest_strokes->emplace_back(std::get<1>(big_stroke),
160 std::get<2>(big_stroke));
161 }
162 }
163
FindNearestNeighborsWithin(int neighbor_count,unsigned long neighbor_min_sample_count,float max_distance,const PalmFilterStroke & stroke,std::vector<std::pair<float,int>> * nearest_strokes) const164 void NeuralStylusPalmDetectionFilter::FindNearestNeighborsWithin(
165 int neighbor_count,
166 unsigned long neighbor_min_sample_count,
167 float max_distance,
168 const PalmFilterStroke& stroke,
169 std::vector<std::pair<float, int>>* nearest_strokes) const {
170 using StrokeId = int;
171 using Distance = float;
172 using DistanceWithStrokeId = std::pair<Distance, StrokeId>;
173 std::priority_queue<DistanceWithStrokeId, std::vector<DistanceWithStrokeId>,
174 std::greater<DistanceWithStrokeId>>
175 queue;
176 if (neighbor_count <= 0) {
177 return;
178 }
179 for (const auto& lookup : strokes_) {
180 const PalmFilterStroke& neighbor = lookup.second;
181 if (neighbor.tracking_id() == stroke.tracking_id()) {
182 continue;
183 }
184 if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
185 model_->config())) {
186 continue;
187 }
188 float distance =
189 EuclideanDistance(neighbor.GetCentroid(), stroke.GetCentroid());
190 if (distance < max_distance) {
191 queue.push(std::make_pair(distance, neighbor.tracking_id()));
192 }
193 }
194 for (int i = 0; i < neighbor_count && !queue.empty(); ++i) {
195 nearest_strokes->push_back(queue.top());
196 queue.pop();
197 }
198 }
199
Filter(const std::vector<InProgressTouchEvdev> & touches,base::TimeTicks time,std::bitset<kNumTouchEvdevSlots> * slots_to_hold,std::bitset<kNumTouchEvdevSlots> * slots_to_suppress)200 void NeuralStylusPalmDetectionFilter::Filter(
201 const std::vector<InProgressTouchEvdev>& touches,
202 base::TimeTicks time,
203 std::bitset<kNumTouchEvdevSlots>* slots_to_hold,
204 std::bitset<kNumTouchEvdevSlots>* slots_to_suppress) {
205 EraseOldStrokes(time);
206 slots_to_hold->reset();
207 slots_to_suppress->reset();
208 std::unordered_set<int> slots_to_decide;
209 std::vector<int> ended_tracking_ids;
210 uint32_t total_finger_touching = 0;
211 for (const auto& touch : touches) {
212 if (touch.touching && touch.tool_code != BTN_TOOL_PEN) {
213 total_finger_touching++;
214 if (!touch.was_touching) {
215 shared_palm_state_->latest_finger_touch_time = time;
216 }
217 }
218 // Ignore touch events that are not touches.
219 if (!touch.touching && !touch.was_touching) {
220 continue;
221 }
222 int tracking_id = touch.tracking_id;
223 const size_t slot = touch.slot;
224 if (!touch.was_touching) {
225 // New stroke, so add the new stroke to the stroke list.
226 DCHECK_NE(tracking_id, -1);
227 DCHECK(strokes_.count(tracking_id) == 0)
228 << " Tracking id " << tracking_id;
229
230 strokes_.emplace(tracking_id,
231 PalmFilterStroke(model_->config(), tracking_id));
232 tracking_ids_[slot] = tracking_id;
233 is_palm_.set(slot, false);
234 is_delay_.set(slot, false);
235 }
236
237 const bool end_of_stroke = tracking_id == -1;
238 if (end_of_stroke) {
239 // Recover the tracking ID.
240 tracking_id = tracking_ids_[slot];
241 }
242
243 DCHECK_NE(tracking_id, -1);
244
245 auto insert_result = active_tracking_ids_.insert(tracking_id);
246 // New tracking_id.
247 if (insert_result.second)
248 tracking_ids_count_within_session_++;
249
250 // Find the stroke in the stroke list.
251 auto stroke_it = strokes_.find(tracking_id);
252
253 if (stroke_it == strokes_.end()) {
254 // TODO(crbug.com/1256926): Work out why this is hit on long presses.
255 DVLOG(1) << "No stroke found, continue.";
256 continue;
257 }
258
259 const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
260
261 PalmFilterStroke& stroke = stroke_it->second;
262 if (end_of_stroke) {
263 // This is a stroke that hasn't had a decision yet, so we force decide.
264 if (!HasDecidedStroke(stroke, config)) {
265 slots_to_decide.insert(slot);
266 }
267
268 ended_tracking_ids.push_back(tracking_id);
269 continue;
270 }
271
272 // Add the sample to the stroke.
273 stroke.ProcessSample(CreatePalmFilterSample(touch, time, model_->config(),
274 palm_filter_dev_info_));
275 if (!is_palm_.test(slot) && ShouldDecideStroke(stroke)) {
276 // slots_to_decide will have is_delay_ set to false anyway, no need to do
277 // the delay detection.
278 slots_to_decide.insert(slot);
279 continue;
280 }
281
282 // Heuristic delay detection.
283 if (config.heuristic_delay_start_if_palm && !end_of_stroke &&
284 !HasDecidedStroke(stroke, config) && IsHeuristicPalmStroke(stroke)) {
285 // A stroke that we _think_ may be a palm, but is too short to decide
286 // yet. So we mark for delay for now.
287 is_delay_.set(slot, true);
288 }
289
290 // Early stage delay detection that marks suspicious palms for delay.
291 if (!is_delay_.test(slot) && config.nn_delay_start_if_palm &&
292 IsEarlyStageSample(stroke, config)) {
293 VLOG(1) << "About to run a early_stage prediction.";
294 if (DetectSpuriousStroke(ExtractFeatures(tracking_id),
295 model_->config().output_threshold)) {
296 VLOG(1) << "hold detected.";
297 is_delay_.set(slot, true);
298 }
299 }
300 }
301
302 for (const int slot : slots_to_decide) {
303 is_delay_.set(slot, false);
304 is_palm_.set(slot, false);
305 int tracking_id = tracking_ids_[slot];
306 auto lookup = strokes_.find(tracking_id);
307 if (lookup == strokes_.end()) {
308 LOG(DFATAL) << "Unable to find marked stroke.";
309 continue;
310 }
311 const auto& stroke = lookup->second;
312 if (IsVeryShortStroke(stroke, model_->config())) {
313 // in very short strokes: we use a heuristic.
314 is_palm_.set(slot, IsHeuristicPalmStroke(stroke));
315 continue;
316 }
317 is_palm_.set(slot, DetectSpuriousStroke(ExtractFeatures(tracking_id),
318 model_->config().output_threshold));
319 if (is_palm_.test(slot)) {
320 shared_palm_state_->latest_palm_touch_time = time;
321 }
322 }
323
324 for (const int tracking_id : ended_tracking_ids) {
325 active_tracking_ids_.erase(tracking_id);
326 }
327
328 *slots_to_suppress |= is_palm_;
329 *slots_to_hold |= is_delay_;
330
331 shared_palm_state_->active_palm_touches = is_palm_.count();
332 shared_palm_state_->active_finger_touches =
333 total_finger_touching - is_palm_.count();
334 }
335
ShouldDecideStroke(const PalmFilterStroke & stroke) const336 bool NeuralStylusPalmDetectionFilter::ShouldDecideStroke(
337 const PalmFilterStroke& stroke) const {
338 const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
339 // Inference only executed once per stroke
340 if (!config.resample_period) {
341 return stroke.samples_seen() == config.max_sample_count;
342 }
343 return stroke.LastSampleCrossed(
344 config.GetEquivalentDuration(config.max_sample_count));
345 }
346
IsHeuristicPalmStroke(const PalmFilterStroke & stroke) const347 bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
348 const PalmFilterStroke& stroke) const {
349 const auto& config = model_->config();
350 if (config.resample_period) {
351 if (stroke.Duration() >
352 config.GetEquivalentDuration(config.max_sample_count)) {
353 LOG(DFATAL)
354 << "Should not call this method on long strokes. Got duration = "
355 << stroke.Duration();
356 return false;
357 }
358 } else {
359 if (stroke.samples().size() >= config.max_sample_count) {
360 LOG(DFATAL) << "Should not call this method on long strokes.";
361 return false;
362 }
363 }
364
365 if (config.heuristic_palm_touch_limit > 0.0) {
366 if (stroke.MaxMajorRadius() >= config.heuristic_palm_touch_limit) {
367 VLOG(1) << "IsHeuristicPalm: Yes major radius.";
368 return true;
369 }
370 }
371 if (config.heuristic_palm_area_limit > 0.0) {
372 if (stroke.BiggestSize() >= config.heuristic_palm_area_limit) {
373 VLOG(1) << "IsHeuristicPalm: Yes area.";
374 return true;
375 }
376 std::vector<std::pair<float, int>> biggest_strokes;
377 FindBiggestNeighborsWithin(
378 1 /* neighbors */, 1 /* neighbor min sample count */,
379 config.max_neighbor_distance_in_mm, stroke, &biggest_strokes);
380 if (!biggest_strokes.empty() &&
381 strokes_.find(biggest_strokes[0].second)->second.BiggestSize() >=
382 config.heuristic_palm_area_limit) {
383 VLOG(1) << "IsHeuristicPalm: Yes neighbor area.";
384 return true;
385 }
386 }
387 VLOG(1) << "IsHeuristicPalm: No.";
388 return false;
389 }
390
DetectSpuriousStroke(const std::vector<float> & features,float threshold) const391 bool NeuralStylusPalmDetectionFilter::DetectSpuriousStroke(
392 const std::vector<float>& features,
393 float threshold) const {
394 auto inference_value = model_->Inference(features);
395 if (VLOG_IS_ON(1)) {
396 VLOG(1) << "Running Inference, features are:";
397 for (std::vector<float>::size_type i = 0; i < features.size(); ++i) {
398 VLOG(1) << "Feature " << i << " is " << features[i];
399 }
400 VLOG(1) << "Inference value is : " << inference_value;
401 }
402 return inference_value >= threshold;
403 }
404
ExtractFeatures(int tracking_id) const405 std::vector<float> NeuralStylusPalmDetectionFilter::ExtractFeatures(
406 int tracking_id) const {
407 std::vector<float> features;
408 const PalmFilterStroke& stroke = strokes_.at(tracking_id);
409 AppendFeatures(stroke, &features);
410 const int features_per_stroke = features.size();
411 std::vector<std::pair<float, int>> nearest_strokes, biggest_strokes;
412 const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
413 FindNearestNeighborsWithin(
414 config.nearest_neighbor_count, config.neighbor_min_sample_count,
415 config.max_neighbor_distance_in_mm, stroke, &nearest_strokes);
416 FindBiggestNeighborsWithin(
417 config.biggest_near_neighbor_count, config.neighbor_min_sample_count,
418 config.max_neighbor_distance_in_mm, stroke, &biggest_strokes);
419 for (uint32_t i = 0; i < config.nearest_neighbor_count; ++i) {
420 if (i < nearest_strokes.size()) {
421 const auto& nearest_stroke = nearest_strokes[i];
422 AppendFeaturesAsNeighbor(strokes_.at(nearest_stroke.second),
423 nearest_stroke.first, &features);
424 } else {
425 features.resize(
426 features.size() + features_per_stroke + kExtraFeaturesForNeighbor, 0);
427 }
428 }
429
430 for (uint32_t i = 0; i < config.biggest_near_neighbor_count; ++i) {
431 if (i < biggest_strokes.size()) {
432 const auto& biggest_stroke = biggest_strokes[i];
433 AppendFeaturesAsNeighbor(strokes_.at(biggest_stroke.second),
434 biggest_stroke.first, &features);
435 } else {
436 features.resize(
437 features.size() + features_per_stroke + kExtraFeaturesForNeighbor, 0);
438 }
439 }
440
441 if (config.use_tracking_id_count) {
442 features.push_back(tracking_ids_count_within_session_);
443 }
444
445 if (config.use_active_tracking_id_count) {
446 features.push_back(active_tracking_ids_.size());
447 }
448
449 return features;
450 }
451
AppendFeatures(const PalmFilterStroke & stroke,std::vector<float> * features) const452 void NeuralStylusPalmDetectionFilter::AppendFeatures(
453 const PalmFilterStroke& stroke,
454 std::vector<float>* features) const {
455 if (model_->config().resample_period) {
456 return AppendResampledFeatures(stroke, features);
457 }
458 const int size = stroke.samples().size();
459 for (int i = 0; i < size; ++i) {
460 const PalmFilterSample& sample = stroke.samples()[i];
461 features->push_back(sample.major_radius);
462 features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
463 : sample.minor_radius);
464 float distance = 0;
465 if (i != 0) {
466 distance = EuclideanDistance(stroke.samples()[i - 1].point, sample.point);
467 }
468 features->push_back(distance);
469 features->push_back(sample.edge);
470 features->push_back(1.0); // existence.
471 }
472 const int padding = model_->config().max_sample_count - size;
473 DCHECK_GE(padding, 0);
474
475 for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
476 features->push_back(0.0);
477 }
478 // "fill proportion."
479 features->push_back(static_cast<float>(size) /
480 model_->config().max_sample_count);
481 features->push_back(EuclideanDistance(stroke.samples().front().point,
482 stroke.samples().back().point));
483
484 // Start sequence number. 0 is min.
485 uint32_t samples_seen = stroke.samples_seen();
486 if (samples_seen < model_->config().max_sample_count) {
487 features->push_back(0);
488 } else {
489 features->push_back(samples_seen - model_->config().max_sample_count);
490 }
491 }
492
493 /**
494 * The flow here is similar to 'AppendFeatures' above, but we rely on the
495 * timing of the samples rather than on the explicit number / position of
496 * samples.
497 */
AppendResampledFeatures(const PalmFilterStroke & stroke,std::vector<float> * features) const498 void NeuralStylusPalmDetectionFilter::AppendResampledFeatures(
499 const PalmFilterStroke& stroke,
500 std::vector<float>* features) const {
501 size_t sample_count = 0;
502 const base::TimeTicks& first_time = stroke.samples()[0].time;
503 const base::TimeDelta& resample_period = *model_->config().resample_period;
504 const base::TimeDelta max_duration =
505 model_->config().GetEquivalentDuration(model_->config().max_sample_count);
506 for (auto time = first_time; (time - first_time) <= max_duration &&
507 time <= stroke.samples().back().time;
508 time += resample_period) {
509 sample_count++;
510 const PalmFilterSample& sample = stroke.GetSampleAt(time);
511 features->push_back(sample.major_radius);
512 features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
513 : sample.minor_radius);
514 float distance = 0;
515 if (time != first_time) {
516 distance = EuclideanDistance(
517 stroke.GetSampleAt(time - resample_period).point, sample.point);
518 }
519 features->push_back(distance);
520 features->push_back(sample.edge);
521 features->push_back(1.0); // existence.
522 }
523 const int padding = model_->config().max_sample_count - sample_count;
524 DCHECK_GE(padding, 0);
525
526 for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
527 features->push_back(0.0);
528 }
529 // "fill proportion."
530 features->push_back(static_cast<float>(sample_count) /
531 model_->config().max_sample_count);
532 features->push_back(EuclideanDistance(stroke.samples().front().point,
533 stroke.samples().back().point));
534
535 // Start sequence number. 0 is min.
536 uint32_t samples_seen =
537 (stroke.Duration() / (*model_->config().resample_period)) + 1;
538 if (samples_seen < model_->config().max_sample_count) {
539 features->push_back(0);
540 } else {
541 features->push_back(samples_seen - model_->config().max_sample_count);
542 }
543 }
544
AppendFeaturesAsNeighbor(const PalmFilterStroke & stroke,float distance,std::vector<float> * features) const545 void NeuralStylusPalmDetectionFilter::AppendFeaturesAsNeighbor(
546 const PalmFilterStroke& stroke,
547 float distance,
548 std::vector<float>* features) const {
549 features->push_back(1); // existence.
550 features->push_back(distance);
551 AppendFeatures(stroke, features);
552 }
553
// Values AppendFeaturesAsNeighbor() places before a neighbor's own features:
// an existence flag and the distance.
const int NeuralStylusPalmDetectionFilter::kExtraFeaturesForNeighbor = 2;
// Values emitted per sample by AppendFeatures()/AppendResampledFeatures():
// major radius, minor radius, distance from the previous sample, edge and an
// existence flag. It also sizes the zero padding for missing samples.
const int NeuralStylusPalmDetectionFilter::kFeaturesPerSample = 5;

// Filter name returned by FilterNameForTesting().
const char NeuralStylusPalmDetectionFilter::kFilterName[] =
    "NeuralStylusPalmDetectionFilter";
FilterNameForTesting() const559 std::string NeuralStylusPalmDetectionFilter::FilterNameForTesting() const {
560 return kFilterName;
561 }
562
563 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
564 bool NeuralStylusPalmDetectionFilter::
CompatibleWithNeuralStylusPalmDetectionFilter(const EventDeviceInfo & devinfo)565 CompatibleWithNeuralStylusPalmDetectionFilter(
566 const EventDeviceInfo& devinfo) {
567 return NeuralStylusPalmDetectionFilter::
568 CompatibleWithNeuralStylusPalmDetectionFilter(
569 devinfo, base::CommandLine::ForCurrentProcess()->GetSwitchValueASCII(
570 kOzoneNNPalmSwitchName));
571 }
572
573 bool NeuralStylusPalmDetectionFilter::
CompatibleWithNeuralStylusPalmDetectionFilter(const EventDeviceInfo & devinfo,const std::string & ozone_params_switch_string)574 CompatibleWithNeuralStylusPalmDetectionFilter(
575 const EventDeviceInfo& devinfo,
576 const std::string& ozone_params_switch_string) {
577 if (devinfo.HasStylus()) {
578 return false;
579 }
580 // Though we like having abs_mt_touch_minor, we don't need it.
581 auto code_check = [&devinfo](int code) {
582 if (!devinfo.HasAbsEvent(code)) {
583 return false;
584 }
585 const auto absinfo = devinfo.GetAbsInfoByCode(code);
586 // Ensure minimum is 0, maximum is greater than the minimum.
587 if (absinfo.minimum != 0 || absinfo.maximum <= absinfo.minimum) {
588 return false;
589 }
590 return true;
591 };
592
593 static constexpr int kRequiredAbsMtCodes[] = {
594 ABS_MT_POSITION_X, ABS_MT_POSITION_Y, ABS_MT_TOUCH_MAJOR};
595 if (!std::all_of(std::begin(kRequiredAbsMtCodes),
596 std::end(kRequiredAbsMtCodes), code_check)) {
597 return false;
598 }
599
600 // Optionally, we use touch_minor if it's around, so check it's good if it
601 // is present.
602 if (devinfo.HasAbsEvent(ABS_MT_TOUCH_MINOR) &&
603 !code_check(ABS_MT_TOUCH_MINOR)) {
604 return false;
605 }
606 // Only work with internal touchscreens.
607 if (devinfo.device_type() != INPUT_DEVICE_INTERNAL) {
608 return false;
609 }
610
611 // Check the switch string.
612
613 absl::optional<base::Value> value =
614 base::JSONReader::Read(ozone_params_switch_string);
615 if (value != absl::nullopt && !ozone_params_switch_string.empty()) {
616 if (!value->is_dict()) {
617 return false;
618 }
619 // If the key isn't set, default to false.
620 if (value->FindKey(kOzoneNNPalmTouchCompatibleProperty) == nullptr) {
621 return false;
622 }
623 std::string* touch_string_val =
624 value->FindStringKey(kOzoneNNPalmTouchCompatibleProperty);
625 if (touch_string_val != nullptr) {
626 if (*touch_string_val == "false") {
627 return false;
628 } else if (*touch_string_val == "true") {
629 return true;
630 } else {
631 LOG(DFATAL) << "Unexpected value for nnpalm touch compatible. expected "
632 "\"true\" or \"false\" . Got: "
633 << *touch_string_val;
634 }
635 }
636 }
637 return true;
638 }
639 #endif
640
EraseOldStrokes(base::TimeTicks time)641 void NeuralStylusPalmDetectionFilter::EraseOldStrokes(base::TimeTicks time) {
642 const base::TimeDelta max_age = model_->config().max_dead_neighbor_time;
643 for (auto it = strokes_.begin(); it != strokes_.end();) {
644 DCHECK(!it->second.samples().empty());
645 const base::TimeTicks most_recent_sample_time =
646 it->second.samples().back().time;
647 const auto age = time - most_recent_sample_time;
648 if (age > max_age) {
649 it = strokes_.erase(it);
650 } else {
651 ++it;
652 }
653 }
654
655 // If the blank time is more than max_blank_time, starts a new session.
656 if (time - previous_report_time_ > model_->config().max_blank_time) {
657 tracking_ids_count_within_session_ = 0;
658 active_tracking_ids_.clear();
659 }
660 previous_report_time_ = time;
661 }
662
// Returns |str| with |prefix| inserted at the start of every line. A trailing
// '\n' does not open a new (prefixed) line, and an empty |str| yields "".
static std::string addLinePrefix(std::string str, const std::string& prefix) {
  std::string prefixed;
  bool at_line_start = true;
  for (const char c : str) {
    if (at_line_start) {
      prefixed += prefix;
    }
    prefixed += c;
    at_line_start = (c == '\n');
  }
  return prefixed;
}
678
operator <<(std::ostream & out,const NeuralStylusPalmDetectionFilter & filter)679 std::ostream& operator<<(std::ostream& out,
680 const NeuralStylusPalmDetectionFilter& filter) {
681 out << "NeuralStylusPalmDetectionFilter(\n";
682 out << " is_palm_ = " << filter.is_palm_ << "\n";
683 out << " is_delay_ = " << filter.is_delay_ << "\n";
684 out << " strokes_ =\n";
685 std::stringstream strokes;
686 strokes << filter.strokes_;
687 out << addLinePrefix(strokes.str(), " ") << "\n";
688 out << " previous_report_time_ = " << filter.previous_report_time_ << "\n";
689 out << " active_tracking_ids_ = " << filter.active_tracking_ids_ << "\n";
690 out << " tracking_ids_count_within_session_ = "
691 << filter.tracking_ids_count_within_session_ << "\n";
692 out << " tracking_ids = [";
693 for (int i = 0; i < kNumTouchEvdevSlots; i++) {
694 out << filter.tracking_ids_[i] << ", ";
695 }
696 out << "]\n";
697
698 out << " palm_filter_dev_info_ = " << filter.palm_filter_dev_info_ << "\n";
699 out << ")\n";
700
701 return out;
702 }
703
704 } // namespace ui
705