1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 17 #define TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 18 19 #include "tensorflow/tools/android/test/jni/object_tracking/config.h" 20 #include "tensorflow/tools/android/test/jni/object_tracking/geom.h" 21 #include "tensorflow/tools/android/test/jni/object_tracking/optical_flow.h" 22 #include "tensorflow/tools/android/test/jni/object_tracking/utils.h" 23 24 namespace tf_tracking { 25 26 // Class that helps OpticalFlow to speed up flow computation 27 // by caching coarse-grained flow. 28 class FlowCache { 29 public: FlowCache(const OpticalFlowConfig * const config)30 explicit FlowCache(const OpticalFlowConfig* const config) 31 : config_(config), 32 image_size_(config->image_size), 33 optical_flow_(config), 34 fullframe_matrix_(NULL) { 35 for (int i = 0; i < kNumCacheLevels; ++i) { 36 const int curr_dims = BlockDimForCacheLevel(i); 37 has_cache_[i] = new Image<bool>(curr_dims, curr_dims); 38 displacements_[i] = new Image<Point2f>(curr_dims, curr_dims); 39 } 40 } 41 ~FlowCache()42 ~FlowCache() { 43 for (int i = 0; i < kNumCacheLevels; ++i) { 44 SAFE_DELETE(has_cache_[i]); 45 SAFE_DELETE(displacements_[i]); 46 } 47 delete[](fullframe_matrix_); 48 fullframe_matrix_ = NULL; 49 } 50 NextFrame(ImageData * const new_frame,const float * const align_matrix23)51 void NextFrame(ImageData* const new_frame, 52 const float* const align_matrix23) { 53 ClearCache(); 54 SetFullframeAlignmentMatrix(align_matrix23); 55 optical_flow_.NextFrame(new_frame); 56 } 57 ClearCache()58 void ClearCache() { 59 for (int i = 0; i < kNumCacheLevels; ++i) { 60 has_cache_[i]->Clear(false); 61 } 62 delete[](fullframe_matrix_); 63 fullframe_matrix_ = NULL; 64 } 65 66 // Finds the flow at a point, using the cache for performance. FindFlowAtPoint(const float u_x,const float u_y,float * const flow_x,float * const flow_y)67 bool FindFlowAtPoint(const float u_x, const float u_y, 68 float* const flow_x, float* const flow_y) const { 69 // Get the best guess from the cache. 70 const Point2f guess_from_cache = LookupGuess(u_x, u_y); 71 72 *flow_x = guess_from_cache.x; 73 *flow_y = guess_from_cache.y; 74 75 // Now refine the guess using the image pyramid. 76 for (int pyramid_level = kMinNumPyramidLevelsToUseForAdjustment - 1; 77 pyramid_level >= 0; --pyramid_level) { 78 if (!optical_flow_.FindFlowAtPointSingleLevel( 79 pyramid_level, u_x, u_y, false, flow_x, flow_y)) { 80 return false; 81 } 82 } 83 84 return true; 85 } 86 87 // Determines the displacement of a point, and uses that to calculate a new 88 // position. 89 // Returns true iff the displacement determination worked and the new position 90 // is in the image. FindNewPositionOfPoint(const float u_x,const float u_y,float * final_x,float * final_y)91 bool FindNewPositionOfPoint(const float u_x, const float u_y, 92 float* final_x, float* final_y) const { 93 float flow_x; 94 float flow_y; 95 if (!FindFlowAtPoint(u_x, u_y, &flow_x, &flow_y)) { 96 return false; 97 } 98 99 // Add in the displacement to get the final position. 100 *final_x = u_x + flow_x; 101 *final_y = u_y + flow_y; 102 103 // Assign the best guess, if we're still in the image. 104 if (InRange(*final_x, 0.0f, static_cast<float>(image_size_.width) - 1) && 105 InRange(*final_y, 0.0f, static_cast<float>(image_size_.height) - 1)) { 106 return true; 107 } else { 108 return false; 109 } 110 } 111 112 // Comparison function for qsort. Compare(const void * a,const void * b)113 static int Compare(const void* a, const void* b) { 114 return *reinterpret_cast<const float*>(a) - 115 *reinterpret_cast<const float*>(b); 116 } 117 118 // Returns the median flow within the given bounding box as determined 119 // by a grid_width x grid_height grid. GetMedianFlow(const BoundingBox & bounding_box,const bool filter_by_fb_error,const int grid_width,const int grid_height)120 Point2f GetMedianFlow(const BoundingBox& bounding_box, 121 const bool filter_by_fb_error, 122 const int grid_width, 123 const int grid_height) const { 124 const int kMaxPoints = 100; 125 SCHECK(grid_width * grid_height <= kMaxPoints, 126 "Too many points for Median flow!"); 127 128 const BoundingBox valid_box = bounding_box.Intersect( 129 BoundingBox(0, 0, image_size_.width - 1, image_size_.height - 1)); 130 131 if (valid_box.GetArea() <= 0.0f) { 132 return Point2f(0, 0); 133 } 134 135 float x_deltas[kMaxPoints]; 136 float y_deltas[kMaxPoints]; 137 138 int curr_offset = 0; 139 for (int i = 0; i < grid_width; ++i) { 140 for (int j = 0; j < grid_height; ++j) { 141 const float x_in = valid_box.left_ + 142 (valid_box.GetWidth() * i) / (grid_width - 1); 143 144 const float y_in = valid_box.top_ + 145 (valid_box.GetHeight() * j) / (grid_height - 1); 146 147 float curr_flow_x; 148 float curr_flow_y; 149 const bool success = FindNewPositionOfPoint(x_in, y_in, 150 &curr_flow_x, &curr_flow_y); 151 152 if (success) { 153 x_deltas[curr_offset] = curr_flow_x; 154 y_deltas[curr_offset] = curr_flow_y; 155 ++curr_offset; 156 } else { 157 LOGW("Tracking failure!"); 158 } 159 } 160 } 161 162 if (curr_offset > 0) { 163 qsort(x_deltas, curr_offset, sizeof(*x_deltas), Compare); 164 qsort(y_deltas, curr_offset, sizeof(*y_deltas), Compare); 165 166 return Point2f(x_deltas[curr_offset / 2], y_deltas[curr_offset / 2]); 167 } 168 169 LOGW("No points were valid!"); 170 return Point2f(0, 0); 171 } 172 SetFullframeAlignmentMatrix(const float * const align_matrix23)173 void SetFullframeAlignmentMatrix(const float* const align_matrix23) { 174 if (align_matrix23 != NULL) { 175 if (fullframe_matrix_ == NULL) { 176 fullframe_matrix_ = new float[6]; 177 } 178 179 memcpy(fullframe_matrix_, align_matrix23, 180 6 * sizeof(fullframe_matrix_[0])); 181 } 182 } 183 184 private: LookupGuessFromLevel(const int cache_level,const float x,const float y)185 Point2f LookupGuessFromLevel( 186 const int cache_level, const float x, const float y) const { 187 // LOGE("Looking up guess at %5.2f %5.2f for level %d.", x, y, cache_level); 188 189 // Cutoff at the target level and use the matrix transform instead. 190 if (fullframe_matrix_ != NULL && cache_level == kCacheCutoff) { 191 const float xnew = x * fullframe_matrix_[0] + 192 y * fullframe_matrix_[1] + 193 fullframe_matrix_[2]; 194 const float ynew = x * fullframe_matrix_[3] + 195 y * fullframe_matrix_[4] + 196 fullframe_matrix_[5]; 197 198 return Point2f(xnew - x, ynew - y); 199 } 200 201 const int level_dim = BlockDimForCacheLevel(cache_level); 202 const int pixels_per_cache_block_x = 203 (image_size_.width + level_dim - 1) / level_dim; 204 const int pixels_per_cache_block_y = 205 (image_size_.height + level_dim - 1) / level_dim; 206 const int index_x = x / pixels_per_cache_block_x; 207 const int index_y = y / pixels_per_cache_block_y; 208 209 Point2f displacement; 210 if (!(*has_cache_[cache_level])[index_y][index_x]) { 211 (*has_cache_[cache_level])[index_y][index_x] = true; 212 213 // Get the lower cache level's best guess, if it exists. 214 displacement = cache_level >= kNumCacheLevels - 1 ? 215 Point2f(0, 0) : LookupGuessFromLevel(cache_level + 1, x, y); 216 // LOGI("Best guess at cache level %d is %5.2f, %5.2f.", cache_level, 217 // best_guess.x, best_guess.y); 218 219 // Find the center of the block. 220 const float center_x = (index_x + 0.5f) * pixels_per_cache_block_x; 221 const float center_y = (index_y + 0.5f) * pixels_per_cache_block_y; 222 const int pyramid_level = PyramidLevelForCacheLevel(cache_level); 223 224 // LOGI("cache level %d: [%d, %d (%5.2f / %d, %5.2f / %d)] " 225 // "Querying %5.2f, %5.2f at pyramid level %d, ", 226 // cache_level, index_x, index_y, 227 // x, pixels_per_cache_block_x, y, pixels_per_cache_block_y, 228 // center_x, center_y, pyramid_level); 229 230 // TODO(andrewharp): Turn on FB error filtering. 231 const bool success = optical_flow_.FindFlowAtPointSingleLevel( 232 pyramid_level, center_x, center_y, false, 233 &displacement.x, &displacement.y); 234 235 if (!success) { 236 LOGV("Computation of cached value failed for level %d!", cache_level); 237 } 238 239 // Store the value for later use. 240 (*displacements_[cache_level])[index_y][index_x] = displacement; 241 } else { 242 displacement = (*displacements_[cache_level])[index_y][index_x]; 243 } 244 245 // LOGI("Returning %5.2f, %5.2f for level %d", 246 // displacement.x, displacement.y, cache_level); 247 return displacement; 248 } 249 LookupGuess(const float x,const float y)250 Point2f LookupGuess(const float x, const float y) const { 251 if (x < 0 || x >= image_size_.width || y < 0 || y >= image_size_.height) { 252 return Point2f(0, 0); 253 } 254 255 // LOGI("Looking up guess at %5.2f %5.2f.", x, y); 256 if (kNumCacheLevels > 0) { 257 return LookupGuessFromLevel(0, x, y); 258 } else { 259 return Point2f(0, 0); 260 } 261 } 262 263 // Returns the number of cache bins in each dimension for a given level 264 // of the cache. BlockDimForCacheLevel(const int cache_level)265 int BlockDimForCacheLevel(const int cache_level) const { 266 // The highest (coarsest) cache level has a block dim of kCacheBranchFactor, 267 // thus if there are 4 cache levels, requesting level 3 (0-based) should 268 // return kCacheBranchFactor, level 2 should return kCacheBranchFactor^2, 269 // and so on. 270 int block_dim = kNumCacheLevels; 271 for (int curr_level = kNumCacheLevels - 1; curr_level > cache_level; 272 --curr_level) { 273 block_dim *= kCacheBranchFactor; 274 } 275 return block_dim; 276 } 277 278 // Returns the level of the image pyramid that a given cache level maps to. PyramidLevelForCacheLevel(const int cache_level)279 int PyramidLevelForCacheLevel(const int cache_level) const { 280 // Higher cache and pyramid levels have smaller dimensions. The highest 281 // cache level should refer to the highest image pyramid level. The 282 // lower, finer image pyramid levels are uncached (assuming 283 // kNumCacheLevels < kNumPyramidLevels). 284 return cache_level + (kNumPyramidLevels - kNumCacheLevels); 285 } 286 287 const OpticalFlowConfig* const config_; 288 289 const Size image_size_; 290 OpticalFlow optical_flow_; 291 292 float* fullframe_matrix_; 293 294 // Whether this value is currently present in the cache. 295 Image<bool>* has_cache_[kNumCacheLevels]; 296 297 // The cached displacement values. 298 Image<Point2f>* displacements_[kNumCacheLevels]; 299 300 TF_DISALLOW_COPY_AND_ASSIGN(FlowCache); 301 }; 302 303 } // namespace tf_tracking 304 305 #endif // TENSORFLOW_EXAMPLES_ANDROID_JNI_OBJECT_TRACKING_FLOW_CACHE_H_ 306