1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ 18 19 #include <map> 20 #include <string> 21 22 #include "tensorflow/tools/android/test/jni/object_tracking/config.h" 23 #include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h" 24 #include "tensorflow/tools/android/test/jni/object_tracking/geom.h" 25 #include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h" 26 #include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h" 27 #include "tensorflow/tools/android/test/jni/object_tracking/logging.h" 28 #include "tensorflow/tools/android/test/jni/object_tracking/object_model.h" 29 #include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" 30 #include "tensorflow/tools/android/test/jni/object_tracking/time_log.h" 31 #include "tensorflow/tools/android/test/jni/object_tracking/tracked_object.h" 32 #include "tensorflow/tools/android/test/jni/object_tracking/utils.h" 33 34 namespace tf_tracking { 35 36 typedef std::map<const std::string, TrackedObject*> TrackedObjectMap; 37 38 inline std::ostream& operator<<(std::ostream& stream, 39 const TrackedObjectMap& map) { 40 for (TrackedObjectMap::const_iterator iter = map.begin(); 41 iter != map.end(); ++iter) { 42 const TrackedObject& tracked_object = *iter->second; 43 const std::string& key = iter->first; 44 stream << key << ": " << tracked_object; 45 } 46 return stream; 47 } 48 49 50 // ObjectTracker is the highest-level class in the tracking/detection framework. 51 // It handles basic image processing, keypoint detection, keypoint tracking, 52 // object tracking, and object detection/relocalization. 53 class ObjectTracker { 54 public: 55 ObjectTracker(const TrackerConfig* const config, 56 ObjectDetectorBase* const detector); 57 virtual ~ObjectTracker(); 58 NextFrame(const uint8_t * const new_frame,const int64_t timestamp,const float * const alignment_matrix_2x3)59 virtual void NextFrame(const uint8_t* const new_frame, 60 const int64_t timestamp, 61 const float* const alignment_matrix_2x3) { 62 NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3); 63 } 64 65 // Called upon the arrival of a new frame of raw data. 66 // Does all image processing, keypoint detection, and object 67 // tracking/detection for registered objects. 68 // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that 69 // represents the main transformation that has happened between the last 70 // and the current frame. 71 // Argument align_level is the pyramid level (where 0 == finest) that 72 // the matrix is valid for. 73 virtual void NextFrame(const uint8_t* const new_frame, 74 const uint8_t* const uv_frame, const int64_t timestamp, 75 const float* const alignment_matrix_2x3); 76 77 virtual void RegisterNewObjectWithAppearance(const std::string& id, 78 const uint8_t* const new_frame, 79 const BoundingBox& bounding_box); 80 81 // Updates the position of a tracked object, given that it was known to be at 82 // a certain position at some point in the past. 83 virtual void SetPreviousPositionOfObject(const std::string& id, 84 const BoundingBox& bounding_box, 85 const int64_t timestamp); 86 87 // Sets the current position of the object in the most recent frame provided. 88 virtual void SetCurrentPositionOfObject(const std::string& id, 89 const BoundingBox& bounding_box); 90 91 // Tells the ObjectTracker to stop tracking a target. 92 void ForgetTarget(const std::string& id); 93 94 // Fills the given out_data buffer with the latest detected keypoint 95 // correspondences, first scaled by scale_factor (to adjust for downsampling 96 // that may have occurred elsewhere), then packed in a fixed-point format. 97 int GetKeypointsPacked(uint16_t* const out_data, 98 const float scale_factor) const; 99 100 // Copy the keypoint arrays after computeFlow is called. 101 // out_data should be at least kMaxKeypoints * kKeypointStep long. 102 // Currently, its format is [x1 y1 found x2 y2 score] repeated N times, 103 // where N is the number of keypoints tracked. N is returned as the result. 104 int GetKeypoints(const bool only_found, float* const out_data) const; 105 106 // Returns the current position of a box, given that it was at a certain 107 // position at the given time. 108 BoundingBox TrackBox(const BoundingBox& region, 109 const int64_t timestamp) const; 110 111 // Returns the number of frames that have been passed to NextFrame(). GetNumFrames()112 inline int GetNumFrames() const { 113 return num_frames_; 114 } 115 HaveObject(const std::string & id)116 inline bool HaveObject(const std::string& id) const { 117 return objects_.find(id) != objects_.end(); 118 } 119 120 // Returns the TrackedObject associated with the given id. GetObject(const std::string & id)121 inline const TrackedObject* GetObject(const std::string& id) const { 122 TrackedObjectMap::const_iterator iter = objects_.find(id); 123 CHECK_ALWAYS(iter != objects_.end(), 124 "Unknown object key! \"%s\"", id.c_str()); 125 TrackedObject* const object = iter->second; 126 return object; 127 } 128 129 // Returns the TrackedObject associated with the given id. GetObject(const std::string & id)130 inline TrackedObject* GetObject(const std::string& id) { 131 TrackedObjectMap::iterator iter = objects_.find(id); 132 CHECK_ALWAYS(iter != objects_.end(), 133 "Unknown object key! \"%s\"", id.c_str()); 134 TrackedObject* const object = iter->second; 135 return object; 136 } 137 IsObjectVisible(const std::string & id)138 bool IsObjectVisible(const std::string& id) const { 139 SCHECK(HaveObject(id), "Don't have this object."); 140 141 const TrackedObject* object = GetObject(id); 142 return object->IsVisible(); 143 } 144 145 virtual void Draw(const int canvas_width, const int canvas_height, 146 const float* const frame_to_canvas) const; 147 148 protected: 149 // Creates a new tracked object at the given position. 150 // If an object model is provided, then that model will be associated with the 151 // object. If not, a new model may be created from the appearance at the 152 // initial position and registered with the object detector. 153 virtual TrackedObject* MaybeAddObject(const std::string& id, 154 const Image<uint8_t>& image, 155 const BoundingBox& bounding_box, 156 const ObjectModelBase* object_model); 157 158 // Find the keypoints in the frame before the current frame. 159 // If only one frame exists, keypoints will be found in that frame. 160 void ComputeKeypoints(const bool cached_ok = false); 161 162 // Finds the correspondences for all the points in the current pair of frames. 163 // Stores the results in the given FramePair. 164 void FindCorrespondences(FramePair* const curr_change) const; 165 GetNthIndexFromEnd(const int offset)166 inline int GetNthIndexFromEnd(const int offset) const { 167 return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset); 168 } 169 170 BoundingBox TrackBox(const BoundingBox& region, 171 const FramePair& frame_pair) const; 172 IncrementFrameIndex()173 inline void IncrementFrameIndex() { 174 // Move the current framechange index up. 175 ++num_frames_; 176 ++curr_num_frame_pairs_; 177 178 // If we've got too many, push up the start of the queue. 179 if (curr_num_frame_pairs_ > kNumFrames) { 180 first_frame_index_ = GetNthIndexFromStart(1); 181 --curr_num_frame_pairs_; 182 } 183 } 184 GetNthIndexFromStart(const int offset)185 inline int GetNthIndexFromStart(const int offset) const { 186 SCHECK(offset >= 0 && offset < curr_num_frame_pairs_, 187 "Offset out of range! %d out of %d.", offset, curr_num_frame_pairs_); 188 return (first_frame_index_ + offset) % kNumFrames; 189 } 190 191 void TrackObjects(); 192 193 const std::unique_ptr<const TrackerConfig> config_; 194 195 const int frame_width_; 196 const int frame_height_; 197 198 int64_t curr_time_; 199 200 int num_frames_; 201 202 TrackedObjectMap objects_; 203 204 FlowCache flow_cache_; 205 206 KeypointDetector keypoint_detector_; 207 208 int curr_num_frame_pairs_; 209 int first_frame_index_; 210 211 std::unique_ptr<ImageData> frame1_; 212 std::unique_ptr<ImageData> frame2_; 213 214 FramePair frame_pairs_[kNumFrames]; 215 216 std::unique_ptr<ObjectDetectorBase> detector_; 217 218 int num_detected_; 219 220 private: 221 void TrackTarget(TrackedObject* const object); 222 223 bool GetBestObjectForDetection( 224 const Detection& detection, TrackedObject** match) const; 225 226 void ProcessDetections(std::vector<Detection>* const detections); 227 228 void DetectTargets(); 229 230 // Temp object used in ObjectTracker::CreateNewExample. 231 mutable std::vector<BoundingSquare> squares; 232 233 friend std::ostream& operator<<(std::ostream& stream, 234 const ObjectTracker& tracker); 235 236 TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker); 237 }; 238 239 inline std::ostream& operator<<(std::ostream& stream, 240 const ObjectTracker& tracker) { 241 stream << "Frame size: " << tracker.frame_width_ << "x" 242 << tracker.frame_height_ << std::endl; 243 244 stream << "Num frames: " << tracker.num_frames_ << std::endl; 245 246 stream << "Curr time: " << tracker.curr_time_ << std::endl; 247 248 const int first_frame_index = tracker.GetNthIndexFromStart(0); 249 const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index]; 250 251 const int last_frame_index = tracker.GetNthIndexFromEnd(0); 252 const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index]; 253 254 stream << "first frame: " << first_frame_index << "," 255 << first_frame_pair.end_time_ << " " 256 << "last frame: " << last_frame_index << "," 257 << last_frame_pair.end_time_ << " diff: " 258 << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms" 259 << std::endl; 260 261 stream << "Tracked targets:"; 262 stream << tracker.objects_; 263 264 return stream; 265 } 266 267 } // namespace tf_tracking 268 269 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_ 270