xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/android/test/jni/object_tracking/flow_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
18 
19 #include "tensorflow/tools/android/test/jni/object_tracking/config.h"
20 #include "tensorflow/tools/android/test/jni/object_tracking/geom.h"
21 #include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h"
22 #include "tensorflow/tools/android/test/jni/object_tracking/utils.h"
23 
24 namespace tf_tracking {
25 
26 // Class that helps OpticalFlow to speed up flow computation
27 // by caching coarse-grained flow.
28 class FlowCache {
29  public:
FlowCache(const OpticalFlowConfig * const config)30   explicit FlowCache(const OpticalFlowConfig* const config)
31       : config_(config),
32         image_size_(config->image_size),
33         optical_flow_(config),
34         fullframe_matrix_(NULL) {
35     for (int i = 0; i < kNumCacheLevels; ++i) {
36       const int curr_dims = BlockDimForCacheLevel(i);
37       has_cache_[i] = new Image<bool>(curr_dims, curr_dims);
38       displacements_[i] = new Image<Point2f>(curr_dims, curr_dims);
39     }
40   }
41 
~FlowCache()42   ~FlowCache() {
43     for (int i = 0; i < kNumCacheLevels; ++i) {
44       SAFE_DELETE(has_cache_[i]);
45       SAFE_DELETE(displacements_[i]);
46     }
47     delete[](fullframe_matrix_);
48     fullframe_matrix_ = NULL;
49   }
50 
NextFrame(ImageData * const new_frame,const float * const align_matrix23)51   void NextFrame(ImageData* const new_frame,
52                  const float* const align_matrix23) {
53     ClearCache();
54     SetFullframeAlignmentMatrix(align_matrix23);
55     optical_flow_.NextFrame(new_frame);
56   }
57 
ClearCache()58   void ClearCache() {
59     for (int i = 0; i < kNumCacheLevels; ++i) {
60       has_cache_[i]->Clear(false);
61     }
62     delete[](fullframe_matrix_);
63     fullframe_matrix_ = NULL;
64   }
65 
66   // Finds the flow at a point, using the cache for performance.
FindFlowAtPoint(const float u_x,const float u_y,float * const flow_x,float * const flow_y)67   bool FindFlowAtPoint(const float u_x, const float u_y,
68                        float* const flow_x, float* const flow_y) const {
69     // Get the best guess from the cache.
70     const Point2f guess_from_cache = LookupGuess(u_x, u_y);
71 
72     *flow_x = guess_from_cache.x;
73     *flow_y = guess_from_cache.y;
74 
75     // Now refine the guess using the image pyramid.
76     for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1;
77         pyramid_level >= 0; --pyramid_level) {
78       if (!optical_flow_.FindFlowAtPointSingleLevel(
79           pyramid_level, u_x, u_y, false, flow_x, flow_y)) {
80         return false;
81       }
82     }
83 
84     return true;
85   }
86 
87   // Determines the displacement of a point, and uses that to calculate a new
88   // position.
89   // Returns true iff the displacement determination worked and the new position
90   // is in the image.
FindNewPositionOfPoint(const float u_x,const float u_y,float * final_x,float * final_y)91   bool FindNewPositionOfPoint(const float u_x, const float u_y,
92                               float* final_x, float* final_y) const {
93     float flow_x;
94     float flow_y;
95     if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) {
96       return false;
97     }
98 
99     // Add in the displacement to get the final position.
100     *final_x = u_x + flow_x;
101     *final_y = u_y + flow_y;
102 
103     // Assign the best guess, if we're still in the image.
104     if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) &&
105         InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) {
106       return true;
107     } else {
108       return false;
109     }
110   }
111 
112   // Comparison function for qsort.
Compare(const void * a,const void * b)113   static int Compare(const void* a, const void* b) {
114     return *reinterpret_cast<const float*>(a) -
115            *reinterpret_cast<const float*>(b);
116   }
117 
118   // Returns the median flow within the given bounding box as determined
119   // by a grid_width x grid_height grid.
GetMedianFlow(const BoundingBox & bounding_box,const bool filter_by_fb_error,const int grid_width,const int grid_height)120   Point2f GetMedianFlow(const BoundingBox& bounding_box,
121                         const bool filter_by_fb_error,
122                         const int grid_width,
123                         const int grid_height) const {
124     const int kMaxPoints = 100;
125     SCHECK(grid_width * grid_height <= kMaxPoints,
126           "Too many points for Median flow!");
127 
128     const BoundingBox valid_box = bounding_box.Intersect(
129         BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1));
130 
131     if (valid_box.GetArea() <= 0.0f) {
132       return Point2f(0, 0);
133     }
134 
135     float x_deltas[kMaxPoints];
136     float y_deltas[kMaxPoints];
137 
138     int curr_offset = 0;
139     for (int i = 0; i < grid_width; ++i) {
140       for (int j = 0; j < grid_height; ++j) {
141         const float x_in = valid_box.left_ +
142             (valid_box.GetWidth() * i) / (grid_width - 1);
143 
144         const float y_in = valid_box.top_ +
145             (valid_box.GetHeight() * j) / (grid_height - 1);
146 
147         float curr_flow_x;
148         float curr_flow_y;
149         const bool success = FindNewPositionOfPoint(x_in, y_in,
150                                                     &curr_flow_x, &curr_flow_y);
151 
152         if (success) {
153           x_deltas[curr_offset] = curr_flow_x;
154           y_deltas[curr_offset] = curr_flow_y;
155           ++curr_offset;
156         } else {
157           LOGW("Tracking failure!");
158         }
159       }
160     }
161 
162     if (curr_offset > 0) {
163       qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare);
164       qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare);
165 
166       return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]);
167     }
168 
169     LOGW("No points were valid!");
170     return Point2f(0, 0);
171   }
172 
SetFullframeAlignmentMatrix(const float * const align_matrix23)173   void SetFullframeAlignmentMatrix(const float* const align_matrix23) {
174     if (align_matrix23 != NULL) {
175       if (fullframe_matrix_ == NULL) {
176         fullframe_matrix_ = new float[6];
177       }
178 
179       memcpy(fullframe_matrix_, align_matrix23,
180              6 * sizeof(fullframe_matrix_[0]));
181     }
182   }
183 
184  private:
LookupGuessFromLevel(const int cache_level,const float x,const float y)185   Point2f LookupGuessFromLevel(
186       const int cache_level, const float x, const float y) const {
187     // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level);
188 
189     // Cutoff at the target level and use the matrix transform instead.
190     if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) {
191       const float xnew = x * fullframe_matrix_[0] +
192                          y * fullframe_matrix_[1] +
193                              fullframe_matrix_[2];
194       const float ynew = x * fullframe_matrix_[3] +
195                          y * fullframe_matrix_[4] +
196                              fullframe_matrix_[5];
197 
198       return Point2f(xnew - x, ynew - y);
199     }
200 
201     const int level_dim = BlockDimForCacheLevel(cache_level);
202     const int pixels_per_cache_block_x =
203         (image_size_.width + level_dim - 1) / level_dim;
204     const int pixels_per_cache_block_y =
205         (image_size_.height + level_dim - 1) / level_dim;
206     const int index_x = x / pixels_per_cache_block_x;
207     const int index_y = y / pixels_per_cache_block_y;
208 
209     Point2f displacement;
210     if (!(*has_cache_[cache_level])[index_y][index_x]) {
211       (*has_cache_[cache_level])[index_y][index_x] = true;
212 
213       // Get the lower cache level's best guess, if it exists.
214       displacement = cache_level >= kNumCacheLevels - 1 ?
215           Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y);
216       // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level,
217       //      best_guess.x, best_guess.y);
218 
219       // Find the center of the block.
220       const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x;
221       const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y;
222       const int pyramid_level = PyramidLevelForCacheLevel(cache_level);
223 
224       // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] "
225       //      "Querying %5.2f, %5.2f at pyramid level %d, ",
226       //      cache_level, index_x, index_y,
227       //      x, pixels_per_cache_block_x, y, pixels_per_cache_block_y,
228       //      center_x, center_y, pyramid_level);
229 
230       // TODO(andrewharp): Turn on FB error filtering.
231       const bool success = optical_flow_.FindFlowAtPointSingleLevel(
232           pyramid_level, center_x, center_y, false,
233           &displacement.x, &displacement.y);
234 
235       if (!success) {
236         LOGV("Computation of cached value failed for level %d!", cache_level);
237       }
238 
239       // Store the value for later use.
240       (*displacements_[cache_level])[index_y][index_x] = displacement;
241     } else {
242       displacement = (*displacements_[cache_level])[index_y][index_x];
243     }
244 
245     // LOGI("Returning %5.2f, %5.2f for level %d",
246     //      displacement.x, displacement.y, cache_level);
247     return displacement;
248   }
249 
LookupGuess(const float x,const float y)250   Point2f LookupGuess(const float x, const float y) const {
251     if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) {
252       return Point2f(0, 0);
253     }
254 
255     // LOGI("Looking up guess at %5.2f %5.2f.", x, y);
256     if (kNumCacheLevels > 0) {
257       return LookupGuessFromLevel(0, x, y);
258     } else {
259       return Point2f(0, 0);
260     }
261   }
262 
263   // Returns the number of cache bins in each dimension for a given level
264   // of the cache.
BlockDimForCacheLevel(const int cache_level)265   int BlockDimForCacheLevel(const int cache_level) const {
266     // The highest (coarsest) cache level has a block dim of kCacheBranchFactor,
267     // thus if there are 4 cache levels, requesting level 3 (0-based) should
268     // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2,
269     // and so on.
270     int block_dim = kNumCacheLevels;
271     for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level;
272         --curr_level) {
273       block_dim *= kCacheBranchFactor;
274     }
275     return block_dim;
276   }
277 
278   // Returns the level of the image pyramid that a given cache level maps to.
PyramidLevelForCacheLevel(const int cache_level)279   int PyramidLevelForCacheLevel(const int cache_level) const {
280     // Higher cache and pyramid levels have smaller dimensions. The highest
281     // cache level should refer to the highest image pyramid level. The
282     // lower, finer image pyramid levels are uncached (assuming
283     // kNumCacheLevels < kNumPyramidLevels).
284     return cache_level + (kNumPyramidLevels - kNumCacheLevels);
285   }
286 
287   const OpticalFlowConfig* const config_;
288 
289   const Size image_size_;
290   OpticalFlow optical_flow_;
291 
292   float* fullframe_matrix_;
293 
294   // Whether this value is currently present in the cache.
295   Image<bool>* has_cache_[kNumCacheLevels];
296 
297   // The cached displacement values.
298   Image<Point2f>* displacements_[kNumCacheLevels];
299 
300   TF_DISALLOW_COPY_AND_ASSIGN(FlowCache);
301 };
302 
303 }  // namespace tf_tracking
304 
305 #endif  // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_
306