/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef __RENDER_OPENGL__
#include <GLES/gl.h>
#include <GLES/glext.h>
#endif

#include <cinttypes>
#include <map>
#include <string>

#include "tensorflow/tools/android/test/jni/object_tracking/config.h"
#include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h"
#include "tensorflow/tools/android/test/jni/object_tracking/geom.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h"
#include "tensorflow/tools/android/test/jni/object_tracking/image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/tools/android/test/jni/object_tracking/logging.h"
#include "tensorflow/tools/android/test/jni/object_tracking/object_detector.h"
#include "tensorflow/tools/android/test/jni/object_tracking/object_tracker.h"
#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h"
#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h"
#include "tensorflow/tools/android/test/jni/object_tracking/utils.h"

namespace tf_tracking {

ObjectTracker::ObjectTracker(const TrackerConfig* const config,
                             ObjectDetectorBase* const detector)
    : config_(config),
      frame_width_(config->image_size.width),
      frame_height_(config->image_size.height),
      curr_time_(0),
      num_frames_(0),
      flow_cache_(&config->flow_config),
      keypoint_detector_(&config->keypoint_detector_config),
      curr_num_frame_pairs_(0),
      first_frame_index_(0),
      frame1_(new ImageData(frame_width_, frame_height_)),
      frame2_(new ImageData(frame_width_, frame_height_)),
      detector_(detector),
      num_detected_(0) {
  for (int i = 0; i < kNumFrames; ++i) {
    frame_pairs_[i].Init(-1, -1);
  }
}


ObjectTracker::~ObjectTracker() {
  for (TrackedObjectMap::iterator iter = objects_.begin();
       iter != objects_.end(); iter++) {
    TrackedObject* object = iter->second;
    SAFE_DELETE(object);
  }
}


// Finds the correspondences for all the points in the current pair of frames.
// Stores the results in the given FramePair.
void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
  // Initially mark every keypoint correspondence as not found.
  memset(frame_pair->optical_flow_found_keypoint_, false,
         sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
  TimeLog("Cleared old found keypoints");

  int num_keypoints_found = 0;

  // For every keypoint...
  for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
    Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
    Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;

    if (flow_cache_.FindNewPositionOfPoint(
        keypoint1->pos_.x, keypoint1->pos_.y,
        &keypoint2->pos_.x, &keypoint2->pos_.y)) {
      frame_pair->optical_flow_found_keypoint_[i_feat] = true;
      ++num_keypoints_found;
    }
  }

  TimeLog("Found correspondences");

  LOGV("Found %d of %d keypoint correspondences",
       num_keypoints_found, frame_pair->number_of_keypoints_);
}

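// Advances the tracker by one camera frame: swaps the frame buffers, feeds
// the new luminance/chroma data to the flow cache (and to the detector, if
// one is attached), computes keypoints and flow correspondences when there
// is something to track, and runs detection every kDetectEveryNFrames
// frames. A minimal usage sketch, kept as a comment; the variable names are
// illustrative and assume a populated TrackerConfig, caller-supplied Y/UV
// frame buffers, a 2x3 frame-alignment matrix, and monotonically increasing
// timestamps:
//
//   ObjectTracker tracker(&config, detector);  // detector may be NULL.
//   tracker.NextFrame(y_data, uv_data, timestamp, alignment_2x3);
//   tracker.RegisterNewObjectWithAppearance("target", y_data, initial_box);
//   // ...one NextFrame() call per subsequent frame...
//   tracker.ForgetTarget("target");  // When the object is no longer needed.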
void ObjectTracker::NextFrame(const uint8_t* const new_frame,
                              const uint8_t* const uv_frame,
                              const int64_t timestamp,
                              const float* const alignment_matrix_2x3) {
  IncrementFrameIndex();
  LOGV("Received frame %d", num_frames_);

  FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
  curr_change->Init(curr_time_, timestamp);

  CHECK_ALWAYS(curr_time_ < timestamp,
               "Timestamp must monotonically increase! Went from %" PRId64
               " to %" PRId64 " on frame %d.",
               curr_time_, timestamp, num_frames_);

  curr_time_ = timestamp;

  // Swap the frames.
  frame1_.swap(frame2_);

  frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);

  if (detector_.get() != NULL) {
    detector_->SetImageData(frame2_.get());
  }

  flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);

  if (num_frames_ == 1) {
    // This must be the first frame, so abort.
    return;
  }

  if (config_->always_track || objects_.size() > 0) {
    LOGV("Tracking %zu targets", objects_.size());
    ComputeKeypoints(true);
    TimeLog("Keypoints computed!");

    FindCorrespondences(curr_change);
    TimeLog("Flow computed!");

    TrackObjects();
  }
  TimeLog("Targets tracked!");

  if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
    DetectTargets();
  }
  TimeLog("Detected objects.");
}

TrackedObject* ObjectTracker::MaybeAddObject(
    const std::string& id, const Image<uint8_t>& source_image,
    const BoundingBox& bounding_box, const ObjectModelBase* object_model) {
  // If an object with this id already exists, return it; otherwise create
  // (and, if a detector is registered, train) a new model for it below.
  if (objects_.find(id) != objects_.end()) {
    return objects_[id];
  }

  // Need to get a non-const version of the model, or create a new one if it
  // wasn't given.
  ObjectModelBase* model = NULL;
  if (detector_ != NULL) {
    // If a detector is registered, then this new object must have a model.
    CHECK_ALWAYS(object_model != NULL, "No model given!");
    model = detector_->CreateObjectModel(object_model->GetName());
  }
  TrackedObject* const object =
      new TrackedObject(id, source_image, bounding_box, model);

  objects_[id] = object;
  return object;
}

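// Registers a new object to track under the given id, seeded from the
// appearance of bounding_box within new_frame (a full-frame luminance buffer
// of frame_width_ x frame_height_ bytes). If a detector is attached, a new
// object model is created and given one training step on this appearance.
// A hedged usage sketch, as a comment; the id and box values are purely
// illustrative:
//
//   tracker.RegisterNewObjectWithAppearance(
//       "person_1", y_data, BoundingBox(10, 20, 110, 160));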
void ObjectTracker::RegisterNewObjectWithAppearance(
    const std::string& id, const uint8_t* const new_frame,
    const BoundingBox& bounding_box) {
  ObjectModelBase* object_model = NULL;

  Image<uint8_t> image(frame_width_, frame_height_);
  image.FromArray(new_frame, frame_width_, 1);

  if (detector_ != NULL) {
    object_model = detector_->CreateObjectModel(id);
    CHECK_ALWAYS(object_model != NULL, "Null object model!");

    const IntegralImage integral_image(image);
    object_model->TrackStep(bounding_box, image, integral_image, true);
  }

  // Create an object at this position.
  CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
  if (objects_.find(id) == objects_.end()) {
    TrackedObject* const object =
        MaybeAddObject(id, image, bounding_box, object_model);
    CHECK_ALWAYS(object != NULL, "Object not created!");
  }
}

void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
                                                const BoundingBox& bounding_box,
                                                const int64_t timestamp) {
  CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
  CHECK_ALWAYS(timestamp <= curr_time_,
               "Timestamp too great! %" PRId64 " vs %" PRId64, timestamp,
               curr_time_);

  TrackedObject* const object = GetObject(id);

  // Track this bounding box from the past to the current time.
  const BoundingBox current_position = TrackBox(bounding_box, timestamp);

  object->UpdatePosition(current_position, curr_time_, *frame2_, false);

  VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
          << std::endl;
}


void ObjectTracker::SetCurrentPositionOfObject(
    const std::string& id, const BoundingBox& bounding_box) {
  SetPreviousPositionOfObject(id, bounding_box, curr_time_);
}


void ObjectTracker::ForgetTarget(const std::string& id) {
  LOGV("Forgetting object %s", id.c_str());
  TrackedObject* const object = GetObject(id);
  delete object;
  objects_.erase(id);

  if (detector_ != NULL) {
    detector_->DeleteObjectModel(id);
  }
}

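// Packs each keypoint correspondence found in the most recent frame pair
// into out_data as four consecutive uint16 values (x1, y1, x2, y2), scaled
// by |scale| and converted to fixed point with RealToFixed115. Returns the
// number of keypoints written, so out_data should be sized for
// 4 * kMaxKeypoints values in the worst case. (Assuming RealToFixed115
// keeps 5 fractional bits, a consumer would recover pixel coordinates by
// dividing by 32.0f.)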
int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data,
                                      const float scale) const {
  const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
  uint16_t* curr_data = out_data;
  int num_keypoints = 0;

  for (int i = 0; i < change.number_of_keypoints_; ++i) {
    if (change.optical_flow_found_keypoint_[i]) {
      ++num_keypoints;
      const Point2f& point1 = change.frame1_keypoints_[i].pos_;
      *curr_data++ = RealToFixed115(point1.x * scale);
      *curr_data++ = RealToFixed115(point1.y * scale);

      const Point2f& point2 = change.frame2_keypoints_[i].pos_;
      *curr_data++ = RealToFixed115(point2.x * scale);
      *curr_data++ = RealToFixed115(point2.y * scale);
    }
  }

  return num_keypoints;
}


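// Writes one record of kKeypointStep floats per keypoint from the most
// recent frame pair into out_data, laid out as
// [x1, y1, found (1.0f or -1.0f), x2, y2, score, type]. If only_found is
// true, keypoints without a flow correspondence are skipped. Returns the
// number of records written; out_data should be sized for
// kMaxKeypoints * kKeypointStep floats in the worst case.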
int ObjectTracker::GetKeypoints(const bool only_found,
                                float* const out_data) const {
  int curr_keypoint = 0;
  const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];

  for (int i = 0; i < change.number_of_keypoints_; ++i) {
    if (!only_found || change.optical_flow_found_keypoint_[i]) {
      const int base = curr_keypoint * kKeypointStep;
      out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
      out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;

      out_data[base + 2] =
          change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
      out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
      out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;

      out_data[base + 5] = change.frame1_keypoints_[i].score_;
      out_data[base + 6] = change.frame1_keypoints_[i].type_;
      ++curr_keypoint;
    }
  }

  LOGV("Got %d keypoints.", curr_keypoint);

  return curr_keypoint;
}


BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
                                    const FramePair& frame_pair) const {
  float translation_x;
  float translation_y;

  float scale_x;
  float scale_y;

  BoundingBox tracked_box(region);
  frame_pair.AdjustBox(
      tracked_box, &translation_x, &translation_y, &scale_x, &scale_y);

  tracked_box.Shift(Point2f(translation_x, translation_y));

  if (scale_x > 0 && scale_y > 0) {
    tracked_box.Scale(scale_x, scale_y);
  }
  return tracked_box;
}

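// Tracks |region| forward from |timestamp| to the current time by chaining
// the single-frame-pair TrackBox() adjustment across the cached history.
// For example (comment-only sketch): a box set two frames in the past is
// first adjusted by the flow of the second-to-last frame pair and then by
// the most recent one, yielding an estimate of its current position.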
BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
                                    const int64_t timestamp) const {
  CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
  CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");

  // Anything that ended before the requested timestamp is of no concern to us.
  bool found_it = false;
  int num_frames_back = -1;
  for (int i = 0; i < curr_num_frame_pairs_; ++i) {
    const FramePair& frame_pair =
        frame_pairs_[GetNthIndexFromEnd(i)];

    if (frame_pair.end_time_ <= timestamp) {
      num_frames_back = i - 1;

      if (num_frames_back > 0) {
        LOGV("Went %d out of %d frames before finding frame. (index: %d)",
             num_frames_back, curr_num_frame_pairs_, GetNthIndexFromEnd(i));
      }

      found_it = true;
      break;
    }
  }

  if (!found_it) {
    LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64,
         frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
             frame_pairs_[GetNthIndexFromStart(0)].end_time_,
         frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
  }

  // Loop over all the frames in the queue, tracking the accumulated delta
  // of the point from frame to frame.  It's possible the point could
  // go out of frame, but keep tracking as best we can, using points near
  // the edge of the screen where it went out of bounds.
  BoundingBox tracked_box(region);
  for (int i = num_frames_back; i >= 0; --i) {
    const FramePair& frame_pair = frame_pairs_[GetNthIndexFromEnd(i)];
    SCHECK(frame_pair.end_time_ >= timestamp, "Frame timestamp was too early!");
    tracked_box = TrackBox(tracked_box, frame_pair);
  }
  return tracked_box;
}


// Converts a row-major 3x3 2d transformation matrix to a column-major 4x4
// 3d transformation matrix.
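// For a row-major affine input
//   [a b tx]
//   [c d ty]
//   [0 0  1]
// the output columns are (a, c, 0, 0), (b, d, 0, 0), (0, 0, 1, 0) and
// (tx, ty, 0, 1), which is the column-major layout expected by
// glMultMatrixf() in Draw() below.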
inline void Convert3x3To4x4(
    const float* const in_matrix, float* const out_matrix) {
  // X
  out_matrix[0] = in_matrix[0];
  out_matrix[1] = in_matrix[3];
  out_matrix[2] = 0.0f;
  out_matrix[3] = 0.0f;

  // Y
  out_matrix[4] = in_matrix[1];
  out_matrix[5] = in_matrix[4];
  out_matrix[6] = 0.0f;
  out_matrix[7] = 0.0f;

  // Z
  out_matrix[8] = 0.0f;
  out_matrix[9] = 0.0f;
  out_matrix[10] = 1.0f;
  out_matrix[11] = 0.0f;

  // Translation
  out_matrix[12] = in_matrix[2];
  out_matrix[13] = in_matrix[5];
  out_matrix[14] = 0.0f;
  out_matrix[15] = 1.0f;
}


void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
                         const float* const frame_to_canvas) const {
#ifdef __RENDER_OPENGL__
  glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);

  glMatrixMode(GL_PROJECTION);
  glLoadIdentity();

  glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);

  // To make Y go the right direction (0 at top of frame).
  glScalef(1.0f, -1.0f, 1.0f);
  glTranslatef(0.0f, -canvas_height, 0.0f);

  glMatrixMode(GL_MODELVIEW);
  glLoadIdentity();

  glPushMatrix();

  // Apply the frame to canvas transformation.
  static GLfloat transformation[16];
  Convert3x3To4x4(frame_to_canvas, transformation);
  glMultMatrixf(transformation);

  // Draw tracked object bounding boxes.
  for (TrackedObjectMap::const_iterator iter = objects_.begin();
    iter != objects_.end(); ++iter) {
    TrackedObject* tracked_object = iter->second;
    tracked_object->Draw();
  }

  static const bool kRenderDebugPyramid = false;
  if (kRenderDebugPyramid) {
    glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
    for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
      Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
    }
  }

  static const bool kRenderDebugDerivative = false;
  if (kRenderDebugDerivative) {
    glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
    for (int i = 0; i < kNumPyramidLevels; ++i) {
      const Image<int32_t>& dx = *frame1_->GetSpatialX(i);
      Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight());
      for (int y = 0; y < dx.GetHeight(); ++y) {
        const int32_t* dx_ptr = dx[y];
        uint8_t* dst_ptr = render_image[y];
        for (int x = 0; x < dx.GetWidth(); ++x) {
          *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
        }
      }

      Sprite(render_image).Draw();
    }
  }

  if (detector_ != NULL) {
    glDisable(GL_CULL_FACE);
    detector_->Draw();
  }
  glPopMatrix();
#endif
}

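// Splits |box| into its four quadrants and appends them, plus the original
// box, to |boxes| so that keypoints get distributed across the object rather
// than clustering in one region. For example, assuming GetCenter() returns
// the box midpoint, a box spanning (0, 0) to (100, 80) yields
// (0, 0, 50, 40), (50, 0, 100, 40), (0, 40, 50, 80), (50, 40, 100, 80) and
// the original (0, 0, 100, 80).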
static void AddQuadrants(const BoundingBox& box,
                         std::vector<BoundingBox>* boxes) {
  const Point2f center = box.GetCenter();

  float x1 = box.left_;
  float x2 = center.x;
  float x3 = box.right_;

  float y1 = box.top_;
  float y2 = center.y;
  float y3 = box.bottom_;

  // Upper left.
  boxes->push_back(BoundingBox(x1, y1, x2, y2));

  // Upper right.
  boxes->push_back(BoundingBox(x2, y1, x3, y2));

  // Bottom left.
  boxes->push_back(BoundingBox(x1, y2, x2, y3));

  // Bottom right.
  boxes->push_back(BoundingBox(x2, y2, x3, y3));

  // Whole thing.
  boxes->push_back(box);
}

void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
  const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
  FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];

  std::vector<BoundingBox> boxes;

  for (TrackedObjectMap::iterator object_iter = objects_.begin();
       object_iter != objects_.end(); ++object_iter) {
    BoundingBox box = object_iter->second->GetPosition();
    box.Scale(config_->object_box_scale_factor_for_features,
              config_->object_box_scale_factor_for_features);
    AddQuadrants(box, &boxes);
  }

  AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);

  keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
}


// Given a detection, finds the existing TrackedObject (if any) that best
// matches it and stores it in *match. Returns false if the detection
// overlaps an object it cannot be assigned to (a collision), true otherwise;
// *match may still be NULL when true is returned.
bool ObjectTracker::GetBestObjectForDetection(
    const Detection& detection, TrackedObject** match) const {
  TrackedObject* best_match = NULL;
  float best_overlap = -FLT_MAX;

  LOGV("Looking for matches in %zu objects!", objects_.size());
  for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
      object_iter != objects_.end(); ++object_iter) {
    TrackedObject* const tracked_object = object_iter->second;

    const float overlap = tracked_object->GetPosition().PascalScore(
        detection.GetObjectBoundingBox());

    if (!detector_->AllowSpontaneousDetections() &&
        (detection.GetObjectModel() != tracked_object->GetModel())) {
      if (overlap > 0.0f) {
        return false;
      }
      continue;
    }

    const float jump_distance =
        (tracked_object->GetPosition().GetCenter() -
         detection.GetObjectBoundingBox().GetCenter()).LengthSquared();

    const float allowed_distance =
        tracked_object->GetAllowableDistanceSquared();

    LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
         jump_distance, allowed_distance, overlap);

    // TODO(andrewharp): No need to do this verification twice, eliminate
    // one of the score checks (the other being in OnDetection).
    if (jump_distance < allowed_distance &&
        overlap > best_overlap &&
        tracked_object->GetMatchScore() + kMatchScoreBuffer <
        detection.GetMatchScore()) {
      best_match = tracked_object;
      best_overlap = overlap;
    } else if (overlap > 0.0f) {
      return false;
    }
  }

  *match = best_match;
  return true;
}


void ObjectTracker::ProcessDetections(
    std::vector<Detection>* const detections) {
  LOGV("Initial detection done, iterating over %zu detections now.",
       detections->size());

  const bool spontaneous_detections_allowed =
      detector_->AllowSpontaneousDetections();
  for (std::vector<Detection>::const_iterator it = detections->begin();
      it != detections->end(); ++it) {
    const Detection& detection = *it;
    SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
          "Frame does not contain bounding box!");

    TrackedObject* best_match = NULL;

    const bool no_collisions =
        GetBestObjectForDetection(detection, &best_match);

    // Need to get a non-const version of the model, or create a new one if it
    // wasn't given.
    ObjectModelBase* model =
        const_cast<ObjectModelBase*>(detection.GetObjectModel());

    if (best_match != NULL) {
      if (model != best_match->GetModel()) {
        CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
            "Model for object changed but spontaneous detections not allowed!");
      }
      best_match->OnDetection(model,
                              detection.GetObjectBoundingBox(),
                              detection.GetMatchScore(),
                              curr_time_, *frame2_);
    } else if (no_collisions && spontaneous_detections_allowed) {
      if (detection.GetMatchScore() > kMinimumMatchScore) {
        LOGV("No match, adding it!");
        const ObjectModelBase* model = detection.GetObjectModel();
        std::ostringstream ss;
        // TODO(andrewharp): Generate this in a more general fashion.
        ss << "hand_" << num_detected_++;
        std::string object_name = ss.str();
        MaybeAddObject(object_name, *frame2_->GetImage(),
                       detection.GetObjectBoundingBox(), model);
      }
    }
  }
}


void ObjectTracker::DetectTargets() {
  // Detect all object model types that we're currently tracking.
  std::vector<const ObjectModelBase*> object_models;
  detector_->GetObjectModels(&object_models);
  if (object_models.size() == 0) {
    LOGV("No objects to search for, aborting.");
    return;
  }

  LOGV("Trying to detect %zu models", object_models.size());

  LOGV("Creating test vector!");
  std::vector<BoundingSquare> positions;

  for (TrackedObjectMap::iterator object_iter = objects_.begin();
      object_iter != objects_.end(); ++object_iter) {
    TrackedObject* const tracked_object = object_iter->second;

#if DEBUG_PREDATOR
    positions.push_back(GetCenteredSquare(
        frame2_->GetImage()->GetContainingBox(), 32.0f));
#else
    const BoundingBox& position = tracked_object->GetPosition();

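    // Computes square_size = max(kScanMinSquareSize / f^2, min(w, h)) / f,
    // where f is kLastKnownPositionScaleFactor and (w, h) are the tracked
    // box's dimensions: scan squares follow the object's smaller side, with
    // a lower bound derived from kScanMinSquareSize.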
    const float square_size = MAX(
        kScanMinSquareSize / (kLastKnownPositionScaleFactor *
        kLastKnownPositionScaleFactor),
        MIN(position.GetWidth(),
        position.GetHeight())) / kLastKnownPositionScaleFactor;

    FillWithSquares(frame2_->GetImage()->GetContainingBox(),
                    tracked_object->GetPosition(),
                    square_size,
                    kScanMinSquareSize,
                    kLastKnownPositionScaleFactor,
                    &positions);
#endif
  }

  LOGV("Created test vector!");

  std::vector<Detection> detections;
  LOGV("Detecting!");
  detector_->Detect(positions, &detections);
  LOGV("Found %zu detections", detections.size());

  TimeLog("Finished detection.");

  ProcessDetections(&detections);

  TimeLog("iterated over detections");

  LOGV("Done detecting!");
}


void ObjectTracker::TrackObjects() {
  // TODO(andrewharp): Correlation should be allowed to remove objects too.
  const bool automatic_removal_allowed = detector_.get() != NULL ?
      detector_->AllowSpontaneousDetections() : false;

  LOGV("Tracking %zu objects!", objects_.size());
  std::vector<std::string> dead_objects;
  for (TrackedObjectMap::iterator iter = objects_.begin();
       iter != objects_.end(); iter++) {
    TrackedObject* object = iter->second;
    const BoundingBox tracked_position = TrackBox(
        object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
    object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);

    if (automatic_removal_allowed &&
        object->GetNumConsecutiveFramesBelowThreshold() >
        kMaxNumDetectionFailures * 5) {
      dead_objects.push_back(iter->first);
    }
  }

  if (detector_ != NULL && automatic_removal_allowed) {
    for (std::vector<std::string>::iterator iter = dead_objects.begin();
         iter != dead_objects.end(); iter++) {
      LOGE("Removing object! %s", iter->c_str());
      ForgetTarget(*iter);
    }
  }
  TimeLog("Tracked all objects.");

  LOGV("%zu objects tracked!", objects_.size());
}

}  // namespace tf_tracking