1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter.h"
6 
7 #include <algorithm>
8 #include <functional>
9 #include <memory>
10 #include <queue>
11 #include <string>
12 #include <tuple>
13 #include <unordered_set>
14 #include <utility>
15 #include <vector>
16 
17 #include "base/command_line.h"
18 #include "base/json/json_reader.h"
19 #include "base/logging.h"
20 #include "base/values.h"
21 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
22 #include "ui/events/ozone/evdev/event_device_info.h"
23 #else
24 #include <linux/input-event-codes.h>
25 #endif
26 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
27 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
28 #include "ui/events/ozone/features.h"
29 
30 namespace ui {
31 namespace {
32 // Returns the Euclidean distance between two points.
EuclideanDistance(const gfx::PointF & a,const gfx::PointF & b)33 float EuclideanDistance(const gfx::PointF& a, const gfx::PointF& b) {
34   return (a - b).Length();
35 }
36 
IsEarlyStageSample(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)37 bool IsEarlyStageSample(
38     const PalmFilterStroke& stroke,
39     const NeuralStylusPalmDetectionFilterModelConfig& config) {
40   if (!config.resample_period) {
41     return config.early_stage_sample_counts.find(stroke.samples_seen()) !=
42            config.early_stage_sample_counts.end();
43   }
44   // Duration is not well-defined for sample_count <= 1, so we handle
45   // it separately.
46   if (stroke.samples().empty()) {
47     return false;
48   }
49   if (stroke.samples().size() == 1) {
50     return config.early_stage_sample_counts.find(1) !=
51            config.early_stage_sample_counts.end();
52   }
53   for (const uint32_t sample_count : config.early_stage_sample_counts) {
54     const base::TimeDelta duration = config.GetEquivalentDuration(sample_count);
55     // Previous sample must not have passed the 'duration' threshold, but the
56     // current sample must pass the threshold
57     if (stroke.LastSampleCrossed(duration)) {
58       return true;
59     }
60   }
61   return false;
62 }
63 
HasDecidedStroke(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)64 bool HasDecidedStroke(
65     const PalmFilterStroke& stroke,
66     const NeuralStylusPalmDetectionFilterModelConfig& config) {
67   if (!config.resample_period) {
68     return stroke.samples_seen() >= config.max_sample_count;
69   }
70   const base::TimeDelta max_duration =
71       config.GetEquivalentDuration(config.max_sample_count);
72   return stroke.Duration() >= max_duration;
73 }
74 
IsVeryShortStroke(const PalmFilterStroke & stroke,const NeuralStylusPalmDetectionFilterModelConfig & config)75 bool IsVeryShortStroke(
76     const PalmFilterStroke& stroke,
77     const NeuralStylusPalmDetectionFilterModelConfig& config) {
78   if (!config.resample_period) {
79     return stroke.samples_seen() < config.min_sample_count;
80   }
81   return stroke.Duration() <
82          config.GetEquivalentDuration(config.min_sample_count);
83 }
84 
85 /**
86  * The provided stroke must be a neighbor stroke rather than a stroke currently
87  * being evaluated. The parameter 'neighbor_min_sample_count' might be different
88  * from the config, depending on the specific usage in the caller.
89  */
HasInsufficientDataAsNeighbor(const PalmFilterStroke & neighbor_stroke,size_t neighbor_min_sample_count,const NeuralStylusPalmDetectionFilterModelConfig & config)90 bool HasInsufficientDataAsNeighbor(
91     const PalmFilterStroke& neighbor_stroke,
92     size_t neighbor_min_sample_count,
93     const NeuralStylusPalmDetectionFilterModelConfig& config) {
94   if (!config.resample_period) {
95     return neighbor_stroke.samples().size() < neighbor_min_sample_count;
96   }
97   return neighbor_stroke.Duration() <
98          config.GetEquivalentDuration(neighbor_min_sample_count);
99 }
100 
101 }  // namespace
102 
// Constructs the filter around |palm_model|.
// - Non-Android builds derive the per-device information used for feature
//   extraction from |devinfo| via CreatePalmFilterDeviceInfo().
// - Android builds receive it directly as |palm_filter_device_info|.
// |shared_palm_state| is handed to the PalmDetectionFilter base class.
NeuralStylusPalmDetectionFilter::NeuralStylusPalmDetectionFilter(
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
    const EventDeviceInfo& devinfo,
#else
    PalmFilterDeviceInfo palm_filter_device_info,
#endif
    std::unique_ptr<NeuralStylusPalmDetectionFilterModel> palm_model,
    SharedPalmDetectionFilterState* shared_palm_state)
    : PalmDetectionFilter(shared_palm_state),
      // Incremented in Filter() whenever a tracking id is newly inserted into
      // active_tracking_ids_; reset by EraseOldStrokes() on a new session.
      tracking_ids_count_within_session_(0),
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
      palm_filter_dev_info_(CreatePalmFilterDeviceInfo(devinfo)),
#else
      palm_filter_dev_info_(palm_filter_device_info),
#endif
      model_(std::move(palm_model)) {
#if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
  // Debug builds verify that the caller ran the compatibility check first.
  DCHECK(CompatibleWithNeuralStylusPalmDetectionFilter(devinfo))
      << "One should run compatible check before instantiation.";
#endif
}
124 
~NeuralStylusPalmDetectionFilter()125 NeuralStylusPalmDetectionFilter::~NeuralStylusPalmDetectionFilter() {}
126 
FindBiggestNeighborsWithin(int neighbor_count,unsigned long neighbor_min_sample_count,float max_distance,const PalmFilterStroke & stroke,std::vector<std::pair<float,int>> * biggest_strokes) const127 void NeuralStylusPalmDetectionFilter::FindBiggestNeighborsWithin(
128     int neighbor_count,
129     unsigned long neighbor_min_sample_count,
130     float max_distance,
131     const PalmFilterStroke& stroke,
132     std::vector<std::pair<float, int>>* biggest_strokes) const {
133   if (neighbor_count <= 0) {
134     return;
135   }
136   // Tuple of {size, distance, stroke_id.}
137   std::priority_queue<std::tuple<float, float, int>> biggest_strokes_queue;
138   for (const auto& lookup : strokes_) {
139     const PalmFilterStroke& neighbor = lookup.second;
140     if (neighbor.tracking_id() == stroke.tracking_id()) {
141       continue;
142     }
143     if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
144                                       model_->config())) {
145       continue;
146     }
147     float distance =
148         EuclideanDistance(neighbor.GetCentroid(), stroke.GetCentroid());
149     if (distance > max_distance) {
150       continue;
151     }
152     float size = neighbor.BiggestSize();
153     biggest_strokes_queue.push(
154         std::make_tuple(size, distance, neighbor.tracking_id()));
155   }
156   for (int i = 0; i < neighbor_count && !biggest_strokes_queue.empty(); ++i) {
157     const auto big_stroke = biggest_strokes_queue.top();
158     biggest_strokes_queue.pop();
159     biggest_strokes->emplace_back(std::get<1>(big_stroke),
160                                   std::get<2>(big_stroke));
161   }
162 }
163 
FindNearestNeighborsWithin(int neighbor_count,unsigned long neighbor_min_sample_count,float max_distance,const PalmFilterStroke & stroke,std::vector<std::pair<float,int>> * nearest_strokes) const164 void NeuralStylusPalmDetectionFilter::FindNearestNeighborsWithin(
165     int neighbor_count,
166     unsigned long neighbor_min_sample_count,
167     float max_distance,
168     const PalmFilterStroke& stroke,
169     std::vector<std::pair<float, int>>* nearest_strokes) const {
170   using StrokeId = int;
171   using Distance = float;
172   using DistanceWithStrokeId = std::pair<Distance, StrokeId>;
173   std::priority_queue<DistanceWithStrokeId, std::vector<DistanceWithStrokeId>,
174                       std::greater<DistanceWithStrokeId>>
175       queue;
176   if (neighbor_count <= 0) {
177     return;
178   }
179   for (const auto& lookup : strokes_) {
180     const PalmFilterStroke& neighbor = lookup.second;
181     if (neighbor.tracking_id() == stroke.tracking_id()) {
182       continue;
183     }
184     if (HasInsufficientDataAsNeighbor(neighbor, neighbor_min_sample_count,
185                                       model_->config())) {
186       continue;
187     }
188     float distance =
189         EuclideanDistance(neighbor.GetCentroid(), stroke.GetCentroid());
190     if (distance < max_distance) {
191       queue.push(std::make_pair(distance, neighbor.tracking_id()));
192     }
193   }
194   for (int i = 0; i < neighbor_count && !queue.empty(); ++i) {
195     nearest_strokes->push_back(queue.top());
196     queue.pop();
197   }
198 }
199 
Filter(const std::vector<InProgressTouchEvdev> & touches,base::TimeTicks time,std::bitset<kNumTouchEvdevSlots> * slots_to_hold,std::bitset<kNumTouchEvdevSlots> * slots_to_suppress)200 void NeuralStylusPalmDetectionFilter::Filter(
201     const std::vector<InProgressTouchEvdev>& touches,
202     base::TimeTicks time,
203     std::bitset<kNumTouchEvdevSlots>* slots_to_hold,
204     std::bitset<kNumTouchEvdevSlots>* slots_to_suppress) {
205   EraseOldStrokes(time);
206   slots_to_hold->reset();
207   slots_to_suppress->reset();
208   std::unordered_set<int> slots_to_decide;
209   std::vector<int> ended_tracking_ids;
210   uint32_t total_finger_touching = 0;
211   for (const auto& touch : touches) {
212     if (touch.touching && touch.tool_code != BTN_TOOL_PEN) {
213       total_finger_touching++;
214       if (!touch.was_touching) {
215         shared_palm_state_->latest_finger_touch_time = time;
216       }
217     }
218     // Ignore touch events that are not touches.
219     if (!touch.touching && !touch.was_touching) {
220       continue;
221     }
222     int tracking_id = touch.tracking_id;
223     const size_t slot = touch.slot;
224     if (!touch.was_touching) {
225       // New stroke, so add the new stroke to the stroke list.
226       DCHECK_NE(tracking_id, -1);
227       DCHECK(strokes_.count(tracking_id) == 0)
228           << " Tracking id " << tracking_id;
229 
230       strokes_.emplace(tracking_id,
231                        PalmFilterStroke(model_->config(), tracking_id));
232       tracking_ids_[slot] = tracking_id;
233       is_palm_.set(slot, false);
234       is_delay_.set(slot, false);
235     }
236 
237     const bool end_of_stroke = tracking_id == -1;
238     if (end_of_stroke) {
239       // Recover the tracking ID.
240       tracking_id = tracking_ids_[slot];
241     }
242 
243     DCHECK_NE(tracking_id, -1);
244 
245     auto insert_result = active_tracking_ids_.insert(tracking_id);
246     // New tracking_id.
247     if (insert_result.second)
248       tracking_ids_count_within_session_++;
249 
250     // Find the stroke in the stroke list.
251     auto stroke_it = strokes_.find(tracking_id);
252 
253     if (stroke_it == strokes_.end()) {
254       // TODO(crbug.com/1256926): Work out why this is hit on long presses.
255       DVLOG(1) << "No stroke found, continue.";
256       continue;
257     }
258 
259     const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
260 
261     PalmFilterStroke& stroke = stroke_it->second;
262     if (end_of_stroke) {
263       // This is a stroke that hasn't had a decision yet, so we force decide.
264       if (!HasDecidedStroke(stroke, config)) {
265         slots_to_decide.insert(slot);
266       }
267 
268       ended_tracking_ids.push_back(tracking_id);
269       continue;
270     }
271 
272     // Add the sample to the stroke.
273     stroke.ProcessSample(CreatePalmFilterSample(touch, time, model_->config(),
274                                                 palm_filter_dev_info_));
275     if (!is_palm_.test(slot) && ShouldDecideStroke(stroke)) {
276       // slots_to_decide will have is_delay_ set to false anyway, no need to do
277       // the delay detection.
278       slots_to_decide.insert(slot);
279       continue;
280     }
281 
282     // Heuristic delay detection.
283     if (config.heuristic_delay_start_if_palm && !end_of_stroke &&
284         !HasDecidedStroke(stroke, config) && IsHeuristicPalmStroke(stroke)) {
285       //  A stroke that we _think_ may be a palm, but is too short to decide
286       //  yet. So we mark for delay for now.
287       is_delay_.set(slot, true);
288     }
289 
290     // Early stage delay detection that marks suspicious palms for delay.
291     if (!is_delay_.test(slot) && config.nn_delay_start_if_palm &&
292         IsEarlyStageSample(stroke, config)) {
293       VLOG(1) << "About to run a early_stage prediction.";
294       if (DetectSpuriousStroke(ExtractFeatures(tracking_id),
295                                model_->config().output_threshold)) {
296         VLOG(1) << "hold detected.";
297         is_delay_.set(slot, true);
298       }
299     }
300   }
301 
302   for (const int slot : slots_to_decide) {
303     is_delay_.set(slot, false);
304     is_palm_.set(slot, false);
305     int tracking_id = tracking_ids_[slot];
306     auto lookup = strokes_.find(tracking_id);
307     if (lookup == strokes_.end()) {
308       LOG(DFATAL) << "Unable to find marked stroke.";
309       continue;
310     }
311     const auto& stroke = lookup->second;
312     if (IsVeryShortStroke(stroke, model_->config())) {
313       // in very short strokes: we use a heuristic.
314       is_palm_.set(slot, IsHeuristicPalmStroke(stroke));
315       continue;
316     }
317     is_palm_.set(slot, DetectSpuriousStroke(ExtractFeatures(tracking_id),
318                                             model_->config().output_threshold));
319     if (is_palm_.test(slot)) {
320       shared_palm_state_->latest_palm_touch_time = time;
321     }
322   }
323 
324   for (const int tracking_id : ended_tracking_ids) {
325     active_tracking_ids_.erase(tracking_id);
326   }
327 
328   *slots_to_suppress |= is_palm_;
329   *slots_to_hold |= is_delay_;
330 
331   shared_palm_state_->active_palm_touches = is_palm_.count();
332   shared_palm_state_->active_finger_touches =
333       total_finger_touching - is_palm_.count();
334 }
335 
ShouldDecideStroke(const PalmFilterStroke & stroke) const336 bool NeuralStylusPalmDetectionFilter::ShouldDecideStroke(
337     const PalmFilterStroke& stroke) const {
338   const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
339   // Inference only executed once per stroke
340   if (!config.resample_period) {
341     return stroke.samples_seen() == config.max_sample_count;
342   }
343   return stroke.LastSampleCrossed(
344       config.GetEquivalentDuration(config.max_sample_count));
345 }
346 
IsHeuristicPalmStroke(const PalmFilterStroke & stroke) const347 bool NeuralStylusPalmDetectionFilter::IsHeuristicPalmStroke(
348     const PalmFilterStroke& stroke) const {
349   const auto& config = model_->config();
350   if (config.resample_period) {
351     if (stroke.Duration() >
352         config.GetEquivalentDuration(config.max_sample_count)) {
353       LOG(DFATAL)
354           << "Should not call this method on long strokes. Got duration = "
355           << stroke.Duration();
356       return false;
357     }
358   } else {
359     if (stroke.samples().size() >= config.max_sample_count) {
360       LOG(DFATAL) << "Should not call this method on long strokes.";
361       return false;
362     }
363   }
364 
365   if (config.heuristic_palm_touch_limit > 0.0) {
366     if (stroke.MaxMajorRadius() >= config.heuristic_palm_touch_limit) {
367       VLOG(1) << "IsHeuristicPalm: Yes major radius.";
368       return true;
369     }
370   }
371   if (config.heuristic_palm_area_limit > 0.0) {
372     if (stroke.BiggestSize() >= config.heuristic_palm_area_limit) {
373       VLOG(1) << "IsHeuristicPalm: Yes area.";
374       return true;
375     }
376     std::vector<std::pair<float, int>> biggest_strokes;
377     FindBiggestNeighborsWithin(
378         1 /* neighbors */, 1 /* neighbor min sample count */,
379         config.max_neighbor_distance_in_mm, stroke, &biggest_strokes);
380     if (!biggest_strokes.empty() &&
381         strokes_.find(biggest_strokes[0].second)->second.BiggestSize() >=
382             config.heuristic_palm_area_limit) {
383       VLOG(1) << "IsHeuristicPalm: Yes neighbor area.";
384       return true;
385     }
386   }
387   VLOG(1) << "IsHeuristicPalm: No.";
388   return false;
389 }
390 
DetectSpuriousStroke(const std::vector<float> & features,float threshold) const391 bool NeuralStylusPalmDetectionFilter::DetectSpuriousStroke(
392     const std::vector<float>& features,
393     float threshold) const {
394   auto inference_value = model_->Inference(features);
395   if (VLOG_IS_ON(1)) {
396     VLOG(1) << "Running Inference, features are:";
397     for (std::vector<float>::size_type i = 0; i < features.size(); ++i) {
398       VLOG(1) << "Feature " << i << " is " << features[i];
399     }
400     VLOG(1) << "Inference value is  : " << inference_value;
401   }
402   return inference_value >= threshold;
403 }
404 
ExtractFeatures(int tracking_id) const405 std::vector<float> NeuralStylusPalmDetectionFilter::ExtractFeatures(
406     int tracking_id) const {
407   std::vector<float> features;
408   const PalmFilterStroke& stroke = strokes_.at(tracking_id);
409   AppendFeatures(stroke, &features);
410   const int features_per_stroke = features.size();
411   std::vector<std::pair<float, int>> nearest_strokes, biggest_strokes;
412   const NeuralStylusPalmDetectionFilterModelConfig& config = model_->config();
413   FindNearestNeighborsWithin(
414       config.nearest_neighbor_count, config.neighbor_min_sample_count,
415       config.max_neighbor_distance_in_mm, stroke, &nearest_strokes);
416   FindBiggestNeighborsWithin(
417       config.biggest_near_neighbor_count, config.neighbor_min_sample_count,
418       config.max_neighbor_distance_in_mm, stroke, &biggest_strokes);
419   for (uint32_t i = 0; i < config.nearest_neighbor_count; ++i) {
420     if (i < nearest_strokes.size()) {
421       const auto& nearest_stroke = nearest_strokes[i];
422       AppendFeaturesAsNeighbor(strokes_.at(nearest_stroke.second),
423                                nearest_stroke.first, &features);
424     } else {
425       features.resize(
426           features.size() + features_per_stroke + kExtraFeaturesForNeighbor, 0);
427     }
428   }
429 
430   for (uint32_t i = 0; i < config.biggest_near_neighbor_count; ++i) {
431     if (i < biggest_strokes.size()) {
432       const auto& biggest_stroke = biggest_strokes[i];
433       AppendFeaturesAsNeighbor(strokes_.at(biggest_stroke.second),
434                                biggest_stroke.first, &features);
435     } else {
436       features.resize(
437           features.size() + features_per_stroke + kExtraFeaturesForNeighbor, 0);
438     }
439   }
440 
441   if (config.use_tracking_id_count) {
442     features.push_back(tracking_ids_count_within_session_);
443   }
444 
445   if (config.use_active_tracking_id_count) {
446     features.push_back(active_tracking_ids_.size());
447   }
448 
449   return features;
450 }
451 
AppendFeatures(const PalmFilterStroke & stroke,std::vector<float> * features) const452 void NeuralStylusPalmDetectionFilter::AppendFeatures(
453     const PalmFilterStroke& stroke,
454     std::vector<float>* features) const {
455   if (model_->config().resample_period) {
456     return AppendResampledFeatures(stroke, features);
457   }
458   const int size = stroke.samples().size();
459   for (int i = 0; i < size; ++i) {
460     const PalmFilterSample& sample = stroke.samples()[i];
461     features->push_back(sample.major_radius);
462     features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
463                                                    : sample.minor_radius);
464     float distance = 0;
465     if (i != 0) {
466       distance = EuclideanDistance(stroke.samples()[i - 1].point, sample.point);
467     }
468     features->push_back(distance);
469     features->push_back(sample.edge);
470     features->push_back(1.0);  // existence.
471   }
472   const int padding = model_->config().max_sample_count - size;
473   DCHECK_GE(padding, 0);
474 
475   for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
476     features->push_back(0.0);
477   }
478   // "fill proportion."
479   features->push_back(static_cast<float>(size) /
480                       model_->config().max_sample_count);
481   features->push_back(EuclideanDistance(stroke.samples().front().point,
482                                         stroke.samples().back().point));
483 
484   // Start sequence number. 0 is min.
485   uint32_t samples_seen = stroke.samples_seen();
486   if (samples_seen < model_->config().max_sample_count) {
487     features->push_back(0);
488   } else {
489     features->push_back(samples_seen - model_->config().max_sample_count);
490   }
491 }
492 
493 /**
494  * The flow here is similar to 'AppendFeatures' above, but we rely on the
495  * timing of the samples rather than on the explicit number / position of
496  * samples.
497  */
AppendResampledFeatures(const PalmFilterStroke & stroke,std::vector<float> * features) const498 void NeuralStylusPalmDetectionFilter::AppendResampledFeatures(
499     const PalmFilterStroke& stroke,
500     std::vector<float>* features) const {
501   size_t sample_count = 0;
502   const base::TimeTicks& first_time = stroke.samples()[0].time;
503   const base::TimeDelta& resample_period = *model_->config().resample_period;
504   const base::TimeDelta max_duration =
505       model_->config().GetEquivalentDuration(model_->config().max_sample_count);
506   for (auto time = first_time; (time - first_time) <= max_duration &&
507                                time <= stroke.samples().back().time;
508        time += resample_period) {
509     sample_count++;
510     const PalmFilterSample& sample = stroke.GetSampleAt(time);
511     features->push_back(sample.major_radius);
512     features->push_back(sample.minor_radius <= 0.0 ? sample.major_radius
513                                                    : sample.minor_radius);
514     float distance = 0;
515     if (time != first_time) {
516       distance = EuclideanDistance(
517           stroke.GetSampleAt(time - resample_period).point, sample.point);
518     }
519     features->push_back(distance);
520     features->push_back(sample.edge);
521     features->push_back(1.0);  // existence.
522   }
523   const int padding = model_->config().max_sample_count - sample_count;
524   DCHECK_GE(padding, 0);
525 
526   for (int i = 0; i < padding * kFeaturesPerSample; ++i) {
527     features->push_back(0.0);
528   }
529   // "fill proportion."
530   features->push_back(static_cast<float>(sample_count) /
531                       model_->config().max_sample_count);
532   features->push_back(EuclideanDistance(stroke.samples().front().point,
533                                         stroke.samples().back().point));
534 
535   // Start sequence number. 0 is min.
536   uint32_t samples_seen =
537       (stroke.Duration() / (*model_->config().resample_period)) + 1;
538   if (samples_seen < model_->config().max_sample_count) {
539     features->push_back(0);
540   } else {
541     features->push_back(samples_seen - model_->config().max_sample_count);
542   }
543 }
544 
AppendFeaturesAsNeighbor(const PalmFilterStroke & stroke,float distance,std::vector<float> * features) const545 void NeuralStylusPalmDetectionFilter::AppendFeaturesAsNeighbor(
546     const PalmFilterStroke& stroke,
547     float distance,
548     std::vector<float>* features) const {
549   features->push_back(1);  // existence.
550   features->push_back(distance);
551   AppendFeatures(stroke, features);
552 }
553 
// Number of values AppendFeaturesAsNeighbor() writes ahead of a neighbor's own
// stroke features: an existence flag and the distance.
const int NeuralStylusPalmDetectionFilter::kExtraFeaturesForNeighbor = 2;
// Number of values emitted per (resampled) sample by AppendFeatures() and
// AppendResampledFeatures(): major radius, minor radius, distance to the
// previous sample, edge, and an existence flag.
const int NeuralStylusPalmDetectionFilter::kFeaturesPerSample = 5;
556 
557 const char NeuralStylusPalmDetectionFilter::kFilterName[] =
558     "NeuralStylusPalmDetectionFilter";
FilterNameForTesting() const559 std::string NeuralStylusPalmDetectionFilter::FilterNameForTesting() const {
560   return kFilterName;
561 }
562 
563 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
564 bool NeuralStylusPalmDetectionFilter::
CompatibleWithNeuralStylusPalmDetectionFilter(const EventDeviceInfo & devinfo)565     CompatibleWithNeuralStylusPalmDetectionFilter(
566         const EventDeviceInfo& devinfo) {
567   return NeuralStylusPalmDetectionFilter::
568       CompatibleWithNeuralStylusPalmDetectionFilter(
569           devinfo, base::CommandLine::ForCurrentProcess()->GetSwitchValueASCII(
570                        kOzoneNNPalmSwitchName));
571 }
572 
573 bool NeuralStylusPalmDetectionFilter::
CompatibleWithNeuralStylusPalmDetectionFilter(const EventDeviceInfo & devinfo,const std::string & ozone_params_switch_string)574     CompatibleWithNeuralStylusPalmDetectionFilter(
575         const EventDeviceInfo& devinfo,
576         const std::string& ozone_params_switch_string) {
577   if (devinfo.HasStylus()) {
578     return false;
579   }
580   // Though we like having abs_mt_touch_minor, we don't need it.
581   auto code_check = [&devinfo](int code) {
582     if (!devinfo.HasAbsEvent(code)) {
583       return false;
584     }
585     const auto absinfo = devinfo.GetAbsInfoByCode(code);
586     // Ensure minimum is 0, maximum is greater than the minimum.
587     if (absinfo.minimum != 0 || absinfo.maximum <= absinfo.minimum) {
588       return false;
589     }
590     return true;
591   };
592 
593   static constexpr int kRequiredAbsMtCodes[] = {
594       ABS_MT_POSITION_X, ABS_MT_POSITION_Y, ABS_MT_TOUCH_MAJOR};
595   if (!std::all_of(std::begin(kRequiredAbsMtCodes),
596                    std::end(kRequiredAbsMtCodes), code_check)) {
597     return false;
598   }
599 
600   // Optionally, we use touch_minor if it's around, so check it's good if it
601   // is present.
602   if (devinfo.HasAbsEvent(ABS_MT_TOUCH_MINOR) &&
603       !code_check(ABS_MT_TOUCH_MINOR)) {
604     return false;
605   }
606   // Only work with internal touchscreens.
607   if (devinfo.device_type() != INPUT_DEVICE_INTERNAL) {
608     return false;
609   }
610 
611   // Check the switch string.
612 
613   absl::optional<base::Value> value =
614       base::JSONReader::Read(ozone_params_switch_string);
615   if (value != absl::nullopt && !ozone_params_switch_string.empty()) {
616     if (!value->is_dict()) {
617       return false;
618     }
619     // If the key isn't set, default to false.
620     if (value->FindKey(kOzoneNNPalmTouchCompatibleProperty) == nullptr) {
621       return false;
622     }
623     std::string* touch_string_val =
624         value->FindStringKey(kOzoneNNPalmTouchCompatibleProperty);
625     if (touch_string_val != nullptr) {
626       if (*touch_string_val == "false") {
627         return false;
628       } else if (*touch_string_val == "true") {
629         return true;
630       } else {
631         LOG(DFATAL) << "Unexpected value for nnpalm touch compatible. expected "
632                        "\"true\" or \"false\" . Got: "
633                     << *touch_string_val;
634       }
635     }
636   }
637   return true;
638 }
639 #endif
640 
EraseOldStrokes(base::TimeTicks time)641 void NeuralStylusPalmDetectionFilter::EraseOldStrokes(base::TimeTicks time) {
642   const base::TimeDelta max_age = model_->config().max_dead_neighbor_time;
643   for (auto it = strokes_.begin(); it != strokes_.end();) {
644     DCHECK(!it->second.samples().empty());
645     const base::TimeTicks most_recent_sample_time =
646         it->second.samples().back().time;
647     const auto age = time - most_recent_sample_time;
648     if (age > max_age) {
649       it = strokes_.erase(it);
650     } else {
651       ++it;
652     }
653   }
654 
655   // If the blank time is more than max_blank_time, starts a new session.
656   if (time - previous_report_time_ > model_->config().max_blank_time) {
657     tracking_ids_count_within_session_ = 0;
658     active_tracking_ids_.clear();
659   }
660   previous_report_time_ = time;
661 }
662 
// Returns |str| with |prefix| inserted at the start of every line that has at
// least one character (including a line consisting only of '\n'). A trailing
// '\n' does not get a prefix after it, and an empty |str| yields "".
// Takes |str| by const reference (the previous by-value parameter copied it
// needlessly) and appends directly into a pre-sized std::string instead of
// going through std::stringstream.
static std::string addLinePrefix(const std::string& str,
                                 const std::string& prefix) {
  std::string result;
  result.reserve(str.size() + prefix.size());
  bool at_line_start = true;
  for (const char ch : str) {
    if (at_line_start) {
      result += prefix;
      at_line_start = false;
    }
    result += ch;
    if (ch == '\n')
      at_line_start = true;
  }
  return result;
}
678 
operator <<(std::ostream & out,const NeuralStylusPalmDetectionFilter & filter)679 std::ostream& operator<<(std::ostream& out,
680                          const NeuralStylusPalmDetectionFilter& filter) {
681   out << "NeuralStylusPalmDetectionFilter(\n";
682   out << "  is_palm_ = " << filter.is_palm_ << "\n";
683   out << "  is_delay_ = " << filter.is_delay_ << "\n";
684   out << "  strokes_ =\n";
685   std::stringstream strokes;
686   strokes << filter.strokes_;
687   out << addLinePrefix(strokes.str(), "    ") << "\n";
688   out << "  previous_report_time_ = " << filter.previous_report_time_ << "\n";
689   out << "  active_tracking_ids_ = " << filter.active_tracking_ids_ << "\n";
690   out << "  tracking_ids_count_within_session_ = "
691       << filter.tracking_ids_count_within_session_ << "\n";
692   out << "  tracking_ids = [";
693   for (int i = 0; i < kNumTouchEvdevSlots; i++) {
694     out << filter.tracking_ids_[i] << ", ";
695   }
696   out << "]\n";
697 
698   out << "  palm_filter_dev_info_ = " << filter.palm_filter_dev_info_ << "\n";
699   out << ")\n";
700 
701   return out;
702 }
703 
704 }  // namespace ui
705