/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_
#define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_

#include <cstdint>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "tensorflow/tools/android/test/jni/object_tracking/config.h"
#include "tensorflow/tools/android/test/jni/object_tracking/flow_cache.h"
#include "tensorflow/tools/android/test/jni/object_tracking/geom.h"
#include "tensorflow/tools/android/test/jni/object_tracking/integral_image.h"
#include "tensorflow/tools/android/test/jni/object_tracking/keypoint_detector.h"
#include "tensorflow/tools/android/test/jni/object_tracking/logging.h"
#include "tensorflow/tools/android/test/jni/object_tracking/object_model.h"
#include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h"
#include "tensorflow/tools/android/test/jni/object_tracking/time_log.h"
#include "tensorflow/tools/android/test/jni/object_tracking/tracked_object.h"
#include "tensorflow/tools/android/test/jni/object_tracking/utils.h"

namespace tf_tracking {

typedef std::map<const std::string, TrackedObject*> TrackedObjectMap;

inline std::ostream& operator<<(std::ostream& stream,
                                const TrackedObjectMap& map) {
  for (TrackedObjectMap::const_iterator iter = map.begin();
       iter != map.end(); ++iter) {
    const TrackedObject& tracked_object = *iter->second;
    const std::string& key = iter->first;
    stream << key << ": " << tracked_object;
  }
  return stream;
}


// ObjectTracker is the highest-level class in the tracking/detection framework.
// It handles basic image processing, keypoint detection, keypoint tracking,
// object tracking, and object detection/relocalization.
class ObjectTracker {
 public:
  ObjectTracker(const TrackerConfig* const config,
                ObjectDetectorBase* const detector);
  virtual ~ObjectTracker();

  virtual void NextFrame(const uint8_t* const new_frame,
                         const int64_t timestamp,
                         const float* const alignment_matrix_2x3) {
    NextFrame(new_frame, NULL, timestamp, alignment_matrix_2x3);
  }

  // Called upon the arrival of a new frame of raw data.
  // Does all image processing, keypoint detection, and object
  // tracking/detection for registered objects.
  // Argument alignment_matrix_2x3 is a 2x3 matrix (stored row-wise) that
  // represents the main transformation that has happened between the last
  // and the current frame.
  // Argument align_level is the pyramid level (where 0 == finest) that
  // the matrix is valid for.
  virtual void NextFrame(const uint8_t* const new_frame,
                         const uint8_t* const uv_frame, const int64_t timestamp,
                         const float* const alignment_matrix_2x3);
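
  // Illustrative call sketch (not part of the original API documentation).
  // The buffer names and timestamp variable below are hypothetical; when no
  // inter-frame motion estimate is available, an identity 2x3 matrix (stored
  // row-wise, as described above) expresses "no transformation":
  //
  //   const float kIdentityAlignment[6] = {1.0f, 0.0f, 0.0f,
  //                                        0.0f, 1.0f, 0.0f};
  //   tracker->NextFrame(luminance_frame, uv_frame, frame_timestamp,
  //                      kIdentityAlignment);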

  virtual void RegisterNewObjectWithAppearance(const std::string& id,
                                               const uint8_t* const new_frame,
                                               const BoundingBox& bounding_box);

  // Updates the position of a tracked object, given that it was known to be at
  // a certain position at some point in the past.
  virtual void SetPreviousPositionOfObject(const std::string& id,
                                           const BoundingBox& bounding_box,
                                           const int64_t timestamp);

  // Sets the current position of the object in the most recent frame provided.
  virtual void SetCurrentPositionOfObject(const std::string& id,
                                          const BoundingBox& bounding_box);

  // Tells the ObjectTracker to stop tracking a target.
  void ForgetTarget(const std::string& id);

  // Fills the given out_data buffer with the latest detected keypoint
  // correspondences, first scaled by scale_factor (to adjust for downsampling
  // that may have occurred elsewhere), then packed in a fixed-point format.
  int GetKeypointsPacked(uint16_t* const out_data,
                         const float scale_factor) const;

  // Copy the keypoint arrays after computeFlow is called.
  // out_data should be at least kMaxKeypoints * kKeypointStep long.
  // Currently, its format is [x1 y1 found x2 y2 score] repeated N times,
  // where N is the number of keypoints tracked.  N is returned as the result.
  int GetKeypoints(const bool only_found, float* const out_data) const;
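
  // Hypothetical read-back sketch for the call above (not from the original
  // source).  It assumes each correspondence occupies kKeypointStep
  // consecutive floats laid out as documented:
  //
  //   float packed[kMaxKeypoints * kKeypointStep];
  //   const int num_tracked = tracker->GetKeypoints(true, packed);
  //   for (int i = 0; i < num_tracked; ++i) {
  //     const float* const k = packed + i * kKeypointStep;
  //     // k[0], k[1]: keypoint position in the earlier frame.
  //     // k[2]:       whether a correspondence was found.
  //     // k[3], k[4]: position in the current frame.
  //     // k[5]:       tracking score.
  //   }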

  // Returns the current position of a box, given that it was at a certain
  // position at the given time.
  BoundingBox TrackBox(const BoundingBox& region,
                       const int64_t timestamp) const;

  // Returns the number of frames that have been passed to NextFrame().
  inline int GetNumFrames() const {
    return num_frames_;
  }

  inline bool HaveObject(const std::string& id) const {
    return objects_.find(id) != objects_.end();
  }

  // Returns the TrackedObject associated with the given id.
  inline const TrackedObject* GetObject(const std::string& id) const {
    TrackedObjectMap::const_iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  // Returns the TrackedObject associated with the given id.
  inline TrackedObject* GetObject(const std::string& id) {
    TrackedObjectMap::iterator iter = objects_.find(id);
    CHECK_ALWAYS(iter != objects_.end(),
                 "Unknown object key! \"%s\"", id.c_str());
    TrackedObject* const object = iter->second;
    return object;
  }

  bool IsObjectVisible(const std::string& id) const {
    SCHECK(HaveObject(id), "Don't have this object.");

    const TrackedObject* object = GetObject(id);
    return object->IsVisible();
  }

  virtual void Draw(const int canvas_width, const int canvas_height,
                    const float* const frame_to_canvas) const;

 protected:
  // Creates a new tracked object at the given position.
  // If an object model is provided, then that model will be associated with the
  // object. If not, a new model may be created from the appearance at the
  // initial position and registered with the object detector.
  virtual TrackedObject* MaybeAddObject(const std::string& id,
                                        const Image<uint8_t>& image,
                                        const BoundingBox& bounding_box,
                                        const ObjectModelBase* object_model);

  // Find the keypoints in the frame before the current frame.
  // If only one frame exists, keypoints will be found in that frame.
  void ComputeKeypoints(const bool cached_ok = false);

  // Finds the correspondences for all the points in the current pair of frames.
  // Stores the results in the given FramePair.
  void FindCorrespondences(FramePair* const curr_change) const;

  inline int GetNthIndexFromEnd(const int offset) const {
    return GetNthIndexFromStart(curr_num_frame_pairs_ - 1 - offset);
  }

  BoundingBox TrackBox(const BoundingBox& region,
                       const FramePair& frame_pair) const;

  inline void IncrementFrameIndex() {
    // Move the current framechange index up.
    ++num_frames_;
    ++curr_num_frame_pairs_;

    // If we've got too many, push up the start of the queue.
    if (curr_num_frame_pairs_ > kNumFrames) {
      first_frame_index_ = GetNthIndexFromStart(1);
      --curr_num_frame_pairs_;
    }
  }

  inline int GetNthIndexFromStart(const int offset) const {
    SCHECK(offset >= 0 && offset < curr_num_frame_pairs_,
           "Offset out of range!  %d out of %d.", offset, curr_num_frame_pairs_);
    return (first_frame_index_ + offset) % kNumFrames;
  }
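
  // Worked example of the ring-buffer indexing above (hypothetical values):
  // with kNumFrames == 4, first_frame_index_ == 2 and curr_num_frame_pairs_
  // == 4, the stored pairs occupy slots 2, 3, 0, 1 of frame_pairs_, so
  // GetNthIndexFromStart(0) == 2 is the oldest pair and
  // GetNthIndexFromEnd(0) == 1 is the newest.  The next IncrementFrameIndex()
  // advances first_frame_index_ to 3, so the oldest slot (2) is dropped and
  // becomes the slot into which the newest frame pair is written.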

  void TrackObjects();

  const std::unique_ptr<const TrackerConfig> config_;

  const int frame_width_;
  const int frame_height_;

  int64_t curr_time_;

  int num_frames_;

  TrackedObjectMap objects_;

  FlowCache flow_cache_;

  KeypointDetector keypoint_detector_;

  int curr_num_frame_pairs_;
  int first_frame_index_;

  std::unique_ptr<ImageData> frame1_;
  std::unique_ptr<ImageData> frame2_;

  FramePair frame_pairs_[kNumFrames];

  std::unique_ptr<ObjectDetectorBase> detector_;

  int num_detected_;

 private:
  void TrackTarget(TrackedObject* const object);

  bool GetBestObjectForDetection(
      const Detection& detection, TrackedObject** match) const;

  void ProcessDetections(std::vector<Detection>* const detections);

  void DetectTargets();

  // Temp object used in ObjectTracker::CreateNewExample.
  mutable std::vector<BoundingSquare> squares;

  friend std::ostream& operator<<(std::ostream& stream,
                                  const ObjectTracker& tracker);

  TF_DISALLOW_COPY_AND_ASSIGN(ObjectTracker);
};
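
// Illustrative end-to-end usage sketch (not part of the original header).
// The config, detector, frame buffer, timestamp, alignment matrix, and
// bounding box below are hypothetical placeholders; since config_ and
// detector_ are unique_ptrs, the tracker appears to take ownership of both
// constructor arguments.
//
//   ObjectTracker tracker(my_config, my_detector);
//   tracker.NextFrame(frame_data, timestamp, identity_2x3);
//   tracker.RegisterNewObjectWithAppearance("target", frame_data, initial_box);
//   // ...feed subsequent frames with NextFrame(), then query:
//   if (tracker.HaveObject("target") && tracker.IsObjectVisible("target")) {
//     const TrackedObject* const target = tracker.GetObject("target");
//     // Inspect the returned TrackedObject as needed.
//   }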

inline std::ostream& operator<<(std::ostream& stream,
                                const ObjectTracker& tracker) {
  stream << "Frame size: " << tracker.frame_width_ << "x"
         << tracker.frame_height_ << std::endl;

  stream << "Num frames: " << tracker.num_frames_ << std::endl;

  stream << "Curr time: " << tracker.curr_time_ << std::endl;

  const int first_frame_index = tracker.GetNthIndexFromStart(0);
  const FramePair& first_frame_pair = tracker.frame_pairs_[first_frame_index];

  const int last_frame_index = tracker.GetNthIndexFromEnd(0);
  const FramePair& last_frame_pair = tracker.frame_pairs_[last_frame_index];

  stream << "first frame: " << first_frame_index << ","
         << first_frame_pair.end_time_ << "    "
         << "last frame: " << last_frame_index << ","
         << last_frame_pair.end_time_ << "   diff: "
         << last_frame_pair.end_time_ - first_frame_pair.end_time_ << "ms"
         << std::endl;

  stream << "Tracked targets:";
  stream << tracker.objects_;

  return stream;
}

}  // namespace tf_tracking

#endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_OBJECT_TRACKER_H_