1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifdef __RENDER_OPENGL__
17 #include <GLES/gl.h>
18 #include <GLES/glext.h>
19 #endif
20
21 #include <cinttypes>
22 #include <map>
23 #include <string>
24
25 #include "tensorflow/tools/android/test/jni/object_tracking/config.h"
26 #include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h"
27 #include "tensorflow/tools/android/test/jni/object_tracking/geom.h"
28 #include "tensorflow/tools/android/test/jni/object_tracking/image-inl.h"
29 #include "tensorflow/tools/android/test/jni/object_tracking/image.h"
30 #include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h"
31 #include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h"
32 #include "tensorflow/tools/android/test/jni/object_tracking/logging.h"
33 #include "tensorflow/tools/android/test/jni/object_tracking/object_detector.h"
34 #include "tensorflow/tools/android/test/jni/object_tracking/object_tracker.h"
35 #include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h"
36 #include "tensorflow/tools/android/test/jni/object_tracking/time_log.h"
37 #include "tensorflow/tools/android/test/jni/object_tracking/utils.h"
38
39 namespace tf_tracking {
40
ObjectTracker(const TrackerConfig * const config,ObjectDetectorBase * const detector)41 ObjectTracker::ObjectTracker(const TrackerConfig* const config,
42 ObjectDetectorBase* const detector)
43 : config_(config),
44 frame_width_(config->image_size.width),
45 frame_height_(config->image_size.height),
46 curr_time_(0),
47 num_frames_(0),
48 flow_cache_(&config->flow_config),
49 keypoint_detector_(&config->keypoint_detector_config),
50 curr_num_frame_pairs_(0),
51 first_frame_index_(0),
52 frame1_(new ImageData(frame_width_, frame_height_)),
53 frame2_(new ImageData(frame_width_, frame_height_)),
54 detector_(detector),
55 num_detected_(0) {
56 for (int i = 0; i < kNumFrames; ++i) {
57 frame_pairs_[i].Init(-1, -1);
58 }
59 }
60
61
~ObjectTracker()62 ObjectTracker::~ObjectTracker() {
63 for (TrackedObjectMap::iterator iter = objects_.begin();
64 iter != objects_.end(); iter++) {
65 TrackedObject* object = iter->second;
66 SAFE_DELETE(object);
67 }
68 }
69
70
71 // Finds the correspondences for all the points in the current pair of frames.
72 // Stores the results in the given FramePair.
FindCorrespondences(FramePair * const frame_pair) const73 void ObjectTracker::FindCorrespondences(FramePair* const frame_pair) const {
74 // Keypoints aren't found until they're found.
75 memset(frame_pair->optical_flow_found_keypoint_, false,
76 sizeof(*frame_pair->optical_flow_found_keypoint_) * kMaxKeypoints);
77 TimeLog("Cleared old found keypoints");
78
79 int num_keypoints_found = 0;
80
81 // For every keypoint...
82 for (int i_feat = 0; i_feat < frame_pair->number_of_keypoints_; ++i_feat) {
83 Keypoint* const keypoint1 = frame_pair->frame1_keypoints_ + i_feat;
84 Keypoint* const keypoint2 = frame_pair->frame2_keypoints_ + i_feat;
85
86 if (flow_cache_.FindNewPositionOfPoint(
87 keypoint1->pos_.x, keypoint1->pos_.y,
88 &keypoint2->pos_.x, &keypoint2->pos_.y)) {
89 frame_pair->optical_flow_found_keypoint_[i_feat] = true;
90 ++num_keypoints_found;
91 }
92 }
93
94 TimeLog("Found correspondences");
95
96 LOGV("Found %d of %d keypoint correspondences",
97 num_keypoints_found, frame_pair->number_of_keypoints_);
98 }
99
NextFrame(const uint8_t * const new_frame,const uint8_t * const uv_frame,const int64_t timestamp,const float * const alignment_matrix_2x3)100 void ObjectTracker::NextFrame(const uint8_t* const new_frame,
101 const uint8_t* const uv_frame,
102 const int64_t timestamp,
103 const float* const alignment_matrix_2x3) {
104 IncrementFrameIndex();
105 LOGV("Received frame %d", num_frames_);
106
107 FramePair* const curr_change = frame_pairs_ + GetNthIndexFromEnd(0);
108 curr_change->Init(curr_time_, timestamp);
109
110 CHECK_ALWAYS(curr_time_ < timestamp,
111 "Timestamp must monotonically increase! Went from %" PRId64
112 " to %" PRId64 " on frame %d.",
113 curr_time_, timestamp, num_frames_);
114
115 curr_time_ = timestamp;
116
117 // Swap the frames.
118 frame1_.swap(frame2_);
119
120 frame2_->SetData(new_frame, uv_frame, frame_width_, timestamp, 1);
121
122 if (detector_.get() != NULL) {
123 detector_->SetImageData(frame2_.get());
124 }
125
126 flow_cache_.NextFrame(frame2_.get(), alignment_matrix_2x3);
127
128 if (num_frames_ == 1) {
129 // This must be the first frame, so abort.
130 return;
131 }
132
133 if (config_->always_track || objects_.size() > 0) {
134 LOGV("Tracking %zu targets", objects_.size());
135 ComputeKeypoints(true);
136 TimeLog("Keypoints computed!");
137
138 FindCorrespondences(curr_change);
139 TimeLog("Flow computed!");
140
141 TrackObjects();
142 }
143 TimeLog("Targets tracked!");
144
145 if (detector_.get() != NULL && num_frames_ % kDetectEveryNFrames == 0) {
146 DetectTargets();
147 }
148 TimeLog("Detected objects.");
149 }
150
MaybeAddObject(const std::string & id,const Image<uint8_t> & source_image,const BoundingBox & bounding_box,const ObjectModelBase * object_model)151 TrackedObject* ObjectTracker::MaybeAddObject(
152 const std::string& id, const Image<uint8_t>& source_image,
153 const BoundingBox& bounding_box, const ObjectModelBase* object_model) {
154 // Train the detector if this is a new object.
155 if (objects_.find(id) != objects_.end()) {
156 return objects_[id];
157 }
158
159 // Need to get a non-const version of the model, or create a new one if it
160 // wasn't given.
161 ObjectModelBase* model = NULL;
162 if (detector_ != NULL) {
163 // If a detector is registered, then this new object must have a model.
164 CHECK_ALWAYS(object_model != NULL, "No model given!");
165 model = detector_->CreateObjectModel(object_model->GetName());
166 }
167 TrackedObject* const object =
168 new TrackedObject(id, source_image, bounding_box, model);
169
170 objects_[id] = object;
171 return object;
172 }
173
RegisterNewObjectWithAppearance(const std::string & id,const uint8_t * const new_frame,const BoundingBox & bounding_box)174 void ObjectTracker::RegisterNewObjectWithAppearance(
175 const std::string& id, const uint8_t* const new_frame,
176 const BoundingBox& bounding_box) {
177 ObjectModelBase* object_model = NULL;
178
179 Image<uint8_t> image(frame_width_, frame_height_);
180 image.FromArray(new_frame, frame_width_, 1);
181
182 if (detector_ != NULL) {
183 object_model = detector_->CreateObjectModel(id);
184 CHECK_ALWAYS(object_model != NULL, "Null object model!");
185
186 const IntegralImage integral_image(image);
187 object_model->TrackStep(bounding_box, image, integral_image, true);
188 }
189
190 // Create an object at this position.
191 CHECK_ALWAYS(!HaveObject(id), "Already have this object!");
192 if (objects_.find(id) == objects_.end()) {
193 TrackedObject* const object =
194 MaybeAddObject(id, image, bounding_box, object_model);
195 CHECK_ALWAYS(object != NULL, "Object not created!");
196 }
197 }
198
SetPreviousPositionOfObject(const std::string & id,const BoundingBox & bounding_box,const int64_t timestamp)199 void ObjectTracker::SetPreviousPositionOfObject(const std::string& id,
200 const BoundingBox& bounding_box,
201 const int64_t timestamp) {
202 CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
203 CHECK_ALWAYS(timestamp <= curr_time_,
204 "Timestamp too great! %" PRId64 " vs %" PRId64, timestamp,
205 curr_time_);
206
207 TrackedObject* const object = GetObject(id);
208
209 // Track this bounding box from the past to the current time.
210 const BoundingBox current_position = TrackBox(bounding_box, timestamp);
211
212 object->UpdatePosition(current_position, curr_time_, *frame2_, false);
213
214 VLOG(2) << "Set tracked position for " << id << " to " << bounding_box
215 << std::endl;
216 }
217
218
SetCurrentPositionOfObject(const std::string & id,const BoundingBox & bounding_box)219 void ObjectTracker::SetCurrentPositionOfObject(
220 const std::string& id, const BoundingBox& bounding_box) {
221 SetPreviousPositionOfObject(id, bounding_box, curr_time_);
222 }
223
224
ForgetTarget(const std::string & id)225 void ObjectTracker::ForgetTarget(const std::string& id) {
226 LOGV("Forgetting object %s", id.c_str());
227 TrackedObject* const object = GetObject(id);
228 delete object;
229 objects_.erase(id);
230
231 if (detector_ != NULL) {
232 detector_->DeleteObjectModel(id);
233 }
234 }
235
GetKeypointsPacked(uint16_t * const out_data,const float scale) const236 int ObjectTracker::GetKeypointsPacked(uint16_t* const out_data,
237 const float scale) const {
238 const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
239 uint16_t* curr_data = out_data;
240 int num_keypoints = 0;
241
242 for (int i = 0; i < change.number_of_keypoints_; ++i) {
243 if (change.optical_flow_found_keypoint_[i]) {
244 ++num_keypoints;
245 const Point2f& point1 = change.frame1_keypoints_[i].pos_;
246 *curr_data++ = RealToFixed115(point1.x * scale);
247 *curr_data++ = RealToFixed115(point1.y * scale);
248
249 const Point2f& point2 = change.frame2_keypoints_[i].pos_;
250 *curr_data++ = RealToFixed115(point2.x * scale);
251 *curr_data++ = RealToFixed115(point2.y * scale);
252 }
253 }
254
255 return num_keypoints;
256 }
257
258
GetKeypoints(const bool only_found,float * const out_data) const259 int ObjectTracker::GetKeypoints(const bool only_found,
260 float* const out_data) const {
261 int curr_keypoint = 0;
262 const FramePair& change = frame_pairs_[GetNthIndexFromEnd(0)];
263
264 for (int i = 0; i < change.number_of_keypoints_; ++i) {
265 if (!only_found || change.optical_flow_found_keypoint_[i]) {
266 const int base = curr_keypoint * kKeypointStep;
267 out_data[base + 0] = change.frame1_keypoints_[i].pos_.x;
268 out_data[base + 1] = change.frame1_keypoints_[i].pos_.y;
269
270 out_data[base + 2] =
271 change.optical_flow_found_keypoint_[i] ? 1.0f : -1.0f;
272 out_data[base + 3] = change.frame2_keypoints_[i].pos_.x;
273 out_data[base + 4] = change.frame2_keypoints_[i].pos_.y;
274
275 out_data[base + 5] = change.frame1_keypoints_[i].score_;
276 out_data[base + 6] = change.frame1_keypoints_[i].type_;
277 ++curr_keypoint;
278 }
279 }
280
281 LOGV("Got %d keypoints.", curr_keypoint);
282
283 return curr_keypoint;
284 }
285
286
// Moves |region| across a single frame pair. The translation and scale
// reported by frame_pair.AdjustBox are applied to a copy of |region|: the
// shift always, the scale only when both factors are positive.
BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
                                    const FramePair& frame_pair) const {
  // Outputs of AdjustBox.
  float shift_x;
  float shift_y;
  float factor_x;
  float factor_y;

  BoundingBox result(region);
  frame_pair.AdjustBox(result, &shift_x, &shift_y, &factor_x, &factor_y);

  result.Shift(Point2f(shift_x, shift_y));

  const bool scale_is_usable = factor_x > 0 && factor_y > 0;
  if (scale_is_usable) {
    result.Scale(factor_x, factor_y);
  }
  return result;
}
306
// Tracks |region| from |timestamp| up to the current time by replaying the
// stored frame pairs, oldest relevant pair first.
// |timestamp| must be positive and not in the future. If no stored pair ended
// at or before |timestamp|, a warning is logged and |region| is returned
// unchanged.
BoundingBox ObjectTracker::TrackBox(const BoundingBox& region,
                                    const int64_t timestamp) const {
  CHECK_ALWAYS(timestamp > 0, "Timestamp too low! %" PRId64, timestamp);
  CHECK_ALWAYS(timestamp <= curr_time_, "Timestamp is in the future!");

  // Walk back from the newest pair until one is found that had already ended
  // by |timestamp|; only the pairs newer than that one need replaying.
  int frames_to_replay = -1;
  bool located = false;
  for (int age = 0; age < curr_num_frame_pairs_; ++age) {
    const int pair_index = GetNthIndexFromEnd(age);
    if (frame_pairs_[pair_index].end_time_ > timestamp) {
      continue;
    }

    frames_to_replay = age - 1;
    if (frames_to_replay > 0) {
      LOGV("Went %d out of %d frames before finding frame. (index: %d)",
           frames_to_replay, curr_num_frame_pairs_, pair_index);
    }

    located = true;
    break;
  }

  if (!located) {
    LOGW("History did not go back far enough! %" PRId64 " vs %" PRId64,
         frame_pairs_[GetNthIndexFromEnd(0)].end_time_ -
             frame_pairs_[GetNthIndexFromStart(0)].end_time_,
         frame_pairs_[GetNthIndexFromEnd(0)].end_time_ - timestamp);
  }

  // Accumulate the per-pair motion, from the oldest relevant pair to the
  // newest. The box may leave the frame along the way; tracking continues
  // regardless.
  BoundingBox result(region);
  for (int age = frames_to_replay; age >= 0; --age) {
    const FramePair& step = frame_pairs_[GetNthIndexFromEnd(age)];
    SCHECK(step.end_time_ >= timestamp, "Frame timestamp was too early!");
    result = TrackBox(result, step);
  }
  return result;
}
351
352
// Expands a row-major 2D transformation matrix into a column-major 4x4
// 3D transformation matrix. Only the first two rows of |in_matrix| (elements
// 0 through 5) are read; the Z axis and the bottom row of the result are set
// to identity values.
inline void Convert3x3To4x4(
    const float* const in_matrix, float* const out_matrix) {
  float* const x_axis = out_matrix;
  float* const y_axis = out_matrix + 4;
  float* const z_axis = out_matrix + 8;
  float* const translation = out_matrix + 12;

  // First column of the 2D matrix.
  x_axis[0] = in_matrix[0];
  x_axis[1] = in_matrix[3];
  x_axis[2] = 0.0f;
  x_axis[3] = 0.0f;

  // Second column of the 2D matrix.
  y_axis[0] = in_matrix[1];
  y_axis[1] = in_matrix[4];
  y_axis[2] = 0.0f;
  y_axis[3] = 0.0f;

  // Z is passed through untouched.
  z_axis[0] = 0.0f;
  z_axis[1] = 0.0f;
  z_axis[2] = 1.0f;
  z_axis[3] = 0.0f;

  // Third column of the 2D matrix holds the translation.
  translation[0] = in_matrix[2];
  translation[1] = in_matrix[5];
  translation[2] = 0.0f;
  translation[3] = 1.0f;
}
381
382
Draw(const int canvas_width,const int canvas_height,const float * const frame_to_canvas) const383 void ObjectTracker::Draw(const int canvas_width, const int canvas_height,
384 const float* const frame_to_canvas) const {
385 #ifdef __RENDER_OPENGL__
386 glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT);
387
388 glMatrixMode(GL_PROJECTION);
389 glLoadIdentity();
390
391 glOrthof(0.0f, canvas_width, 0.0f, canvas_height, 0.0f, 1.0f);
392
393 // To make Y go the right direction (0 at top of frame).
394 glScalef(1.0f, -1.0f, 1.0f);
395 glTranslatef(0.0f, -canvas_height, 0.0f);
396
397 glMatrixMode(GL_MODELVIEW);
398 glLoadIdentity();
399
400 glPushMatrix();
401
402 // Apply the frame to canvas transformation.
403 static GLfloat transformation[16];
404 Convert3x3To4x4(frame_to_canvas, transformation);
405 glMultMatrixf(transformation);
406
407 // Draw tracked object bounding boxes.
408 for (TrackedObjectMap::const_iterator iter = objects_.begin();
409 iter != objects_.end(); ++iter) {
410 TrackedObject* tracked_object = iter->second;
411 tracked_object->Draw();
412 }
413
414 static const bool kRenderDebugPyramid = false;
415 if (kRenderDebugPyramid) {
416 glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
417 for (int i = 0; i < kNumPyramidLevels * 2; ++i) {
418 Sprite(*frame1_->GetPyramidSqrt2Level(i)).Draw();
419 }
420 }
421
422 static const bool kRenderDebugDerivative = false;
423 if (kRenderDebugDerivative) {
424 glColor4f(1.0f, 1.0f, 1.0f, 1.0f);
425 for (int i = 0; i < kNumPyramidLevels; ++i) {
426 const Image<int32_t>& dx = *frame1_->GetSpatialX(i);
427 Image<uint8_t> render_image(dx.GetWidth(), dx.GetHeight());
428 for (int y = 0; y < dx.GetHeight(); ++y) {
429 const int32_t* dx_ptr = dx[y];
430 uint8_t* dst_ptr = render_image[y];
431 for (int x = 0; x < dx.GetWidth(); ++x) {
432 *dst_ptr++ = Clip(-(*dx_ptr++), 0, 255);
433 }
434 }
435
436 Sprite(render_image).Draw();
437 }
438 }
439
440 if (detector_ != NULL) {
441 glDisable(GL_CULL_FACE);
442 detector_->Draw();
443 }
444 glPopMatrix();
445 #endif
446 }
447
AddQuadrants(const BoundingBox & box,std::vector<BoundingBox> * boxes)448 static void AddQuadrants(const BoundingBox& box,
449 std::vector<BoundingBox>* boxes) {
450 const Point2f center = box.GetCenter();
451
452 float x1 = box.left_;
453 float x2 = center.x;
454 float x3 = box.right_;
455
456 float y1 = box.top_;
457 float y2 = center.y;
458 float y3 = box.bottom_;
459
460 // Upper left.
461 boxes->push_back(BoundingBox(x1, y1, x2, y2));
462
463 // Upper right.
464 boxes->push_back(BoundingBox(x2, y1, x3, y2));
465
466 // Bottom left.
467 boxes->push_back(BoundingBox(x1, y2, x2, y3));
468
469 // Bottom right.
470 boxes->push_back(BoundingBox(x2, y2, x3, y3));
471
472 // Whole thing.
473 boxes->push_back(box);
474 }
475
ComputeKeypoints(const bool cached_ok)476 void ObjectTracker::ComputeKeypoints(const bool cached_ok) {
477 const FramePair& prev_change = frame_pairs_[GetNthIndexFromEnd(1)];
478 FramePair* const curr_change = &frame_pairs_[GetNthIndexFromEnd(0)];
479
480 std::vector<BoundingBox> boxes;
481
482 for (TrackedObjectMap::iterator object_iter = objects_.begin();
483 object_iter != objects_.end(); ++object_iter) {
484 BoundingBox box = object_iter->second->GetPosition();
485 box.Scale(config_->object_box_scale_factor_for_features,
486 config_->object_box_scale_factor_for_features);
487 AddQuadrants(box, &boxes);
488 }
489
490 AddQuadrants(frame1_->GetImage()->GetContainingBox(), &boxes);
491
492 keypoint_detector_.FindKeypoints(*frame1_, boxes, prev_change, curr_change);
493 }
494
495
496 // Given a vector of detections and a model, simply returns the Detection for
497 // that model with the highest correlation.
GetBestObjectForDetection(const Detection & detection,TrackedObject ** match) const498 bool ObjectTracker::GetBestObjectForDetection(
499 const Detection& detection, TrackedObject** match) const {
500 TrackedObject* best_match = NULL;
501 float best_overlap = -FLT_MAX;
502
503 LOGV("Looking for matches in %zu objects!", objects_.size());
504 for (TrackedObjectMap::const_iterator object_iter = objects_.begin();
505 object_iter != objects_.end(); ++object_iter) {
506 TrackedObject* const tracked_object = object_iter->second;
507
508 const float overlap = tracked_object->GetPosition().PascalScore(
509 detection.GetObjectBoundingBox());
510
511 if (!detector_->AllowSpontaneousDetections() &&
512 (detection.GetObjectModel() != tracked_object->GetModel())) {
513 if (overlap > 0.0f) {
514 return false;
515 }
516 continue;
517 }
518
519 const float jump_distance =
520 (tracked_object->GetPosition().GetCenter() -
521 detection.GetObjectBoundingBox().GetCenter()).LengthSquared();
522
523 const float allowed_distance =
524 tracked_object->GetAllowableDistanceSquared();
525
526 LOGV("Distance: %.2f, Allowed distance %.2f, Overlap: %.2f",
527 jump_distance, allowed_distance, overlap);
528
529 // TODO(andrewharp): No need to do this verification twice, eliminate
530 // one of the score checks (the other being in OnDetection).
531 if (jump_distance < allowed_distance &&
532 overlap > best_overlap &&
533 tracked_object->GetMatchScore() + kMatchScoreBuffer <
534 detection.GetMatchScore()) {
535 best_match = tracked_object;
536 best_overlap = overlap;
537 } else if (overlap > 0.0f) {
538 return false;
539 }
540 }
541
542 *match = best_match;
543 return true;
544 }
545
546
ProcessDetections(std::vector<Detection> * const detections)547 void ObjectTracker::ProcessDetections(
548 std::vector<Detection>* const detections) {
549 LOGV("Initial detection done, iterating over %zu detections now.",
550 detections->size());
551
552 const bool spontaneous_detections_allowed =
553 detector_->AllowSpontaneousDetections();
554 for (std::vector<Detection>::const_iterator it = detections->begin();
555 it != detections->end(); ++it) {
556 const Detection& detection = *it;
557 SCHECK(frame2_->GetImage()->Contains(detection.GetObjectBoundingBox()),
558 "Frame does not contain bounding box!");
559
560 TrackedObject* best_match = NULL;
561
562 const bool no_collisions =
563 GetBestObjectForDetection(detection, &best_match);
564
565 // Need to get a non-const version of the model, or create a new one if it
566 // wasn't given.
567 ObjectModelBase* model =
568 const_cast<ObjectModelBase*>(detection.GetObjectModel());
569
570 if (best_match != NULL) {
571 if (model != best_match->GetModel()) {
572 CHECK_ALWAYS(detector_->AllowSpontaneousDetections(),
573 "Model for object changed but spontaneous detections not allowed!");
574 }
575 best_match->OnDetection(model,
576 detection.GetObjectBoundingBox(),
577 detection.GetMatchScore(),
578 curr_time_, *frame2_);
579 } else if (no_collisions && spontaneous_detections_allowed) {
580 if (detection.GetMatchScore() > kMinimumMatchScore) {
581 LOGV("No match, adding it!");
582 const ObjectModelBase* model = detection.GetObjectModel();
583 std::ostringstream ss;
584 // TODO(andrewharp): Generate this in a more general fashion.
585 ss << "hand_" << num_detected_++;
586 std::string object_name = ss.str();
587 MaybeAddObject(object_name, *frame2_->GetImage(),
588 detection.GetObjectBoundingBox(), model);
589 }
590 }
591 }
592 }
593
594
DetectTargets()595 void ObjectTracker::DetectTargets() {
596 // Detect all object model types that we're currently tracking.
597 std::vector<const ObjectModelBase*> object_models;
598 detector_->GetObjectModels(&object_models);
599 if (object_models.size() == 0) {
600 LOGV("No objects to search for, aborting.");
601 return;
602 }
603
604 LOGV("Trying to detect %zu models", object_models.size());
605
606 LOGV("Creating test vector!");
607 std::vector<BoundingSquare> positions;
608
609 for (TrackedObjectMap::iterator object_iter = objects_.begin();
610 object_iter != objects_.end(); ++object_iter) {
611 TrackedObject* const tracked_object = object_iter->second;
612
613 #if DEBUG_PREDATOR
614 positions.push_back(GetCenteredSquare(
615 frame2_->GetImage()->GetContainingBox(), 32.0f));
616 #else
617 const BoundingBox& position = tracked_object->GetPosition();
618
619 const float square_size = MAX(
620 kScanMinSquareSize / (kLastKnownPositionScaleFactor *
621 kLastKnownPositionScaleFactor),
622 MIN(position.GetWidth(),
623 position.GetHeight())) / kLastKnownPositionScaleFactor;
624
625 FillWithSquares(frame2_->GetImage()->GetContainingBox(),
626 tracked_object->GetPosition(),
627 square_size,
628 kScanMinSquareSize,
629 kLastKnownPositionScaleFactor,
630 &positions);
631 }
632 #endif
633
634 LOGV("Created test vector!");
635
636 std::vector<Detection> detections;
637 LOGV("Detecting!");
638 detector_->Detect(positions, &detections);
639 LOGV("Found %zu detections", detections.size());
640
641 TimeLog("Finished detection.");
642
643 ProcessDetections(&detections);
644
645 TimeLog("iterated over detections");
646
647 LOGV("Done detecting!");
648 }
649
650
TrackObjects()651 void ObjectTracker::TrackObjects() {
652 // TODO(andrewharp): Correlation should be allowed to remove objects too.
653 const bool automatic_removal_allowed = detector_.get() != NULL ?
654 detector_->AllowSpontaneousDetections() : false;
655
656 LOGV("Tracking %zu objects!", objects_.size());
657 std::vector<std::string> dead_objects;
658 for (TrackedObjectMap::iterator iter = objects_.begin();
659 iter != objects_.end(); iter++) {
660 TrackedObject* object = iter->second;
661 const BoundingBox tracked_position = TrackBox(
662 object->GetPosition(), frame_pairs_[GetNthIndexFromEnd(0)]);
663 object->UpdatePosition(tracked_position, curr_time_, *frame2_, false);
664
665 if (automatic_removal_allowed &&
666 object->GetNumConsecutiveFramesBelowThreshold() >
667 kMaxNumDetectionFailures * 5) {
668 dead_objects.push_back(iter->first);
669 }
670 }
671
672 if (detector_ != NULL && automatic_removal_allowed) {
673 for (std::vector<std::string>::iterator iter = dead_objects.begin();
674 iter != dead_objects.end(); iter++) {
675 LOGE("Removing object! %s", iter->c_str());
676 ForgetTarget(*iter);
677 }
678 }
679 TimeLog("Tracked all objects.");
680
681 LOGV("%zu objects tracked!", objects_.size());
682 }
683
684 } // namespace tf_tracking
685