/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <memory>
#include <streambuf>
#include <string>
#include <vector>

#include "absl/base/casts.h"
#include "absl/strings/ascii.h"
#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/preprocessing_steps.pb.h"

namespace tflite {
namespace evaluation {
namespace {

// We assume 3-channel RGB images.
const int kNumChannels = 3;

// Returns the offset for the element in the raw image array based on the image
// height/width & the coordinates (h, w, c) of a pixel.
inline int ImageArrayOffset(int height, int width, int h, int w, int c) {
  return (h * width + w) * kNumChannels + c;
}

// Stores data and size information of an image.
struct ImageData {
  uint32_t width;
  uint32_t height;
  std::unique_ptr<std::vector<float>> data;

  // GetData performs no checks.
  float GetData(int h, int w, int c) {
    return data->at(ImageArrayOffset(height, width, h, w, c));
  }
};

// Loads the raw image.
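// Raw (.rgb8) files are expected to contain interleaved uint8 RGB values with
// no size header, so width/height are left unset here.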
inline void LoadImageRaw(std::string* filename, ImageData* image_data) {
  std::ifstream stream(filename->c_str(), std::ios::in | std::ios::binary);
  std::vector<uint8_t> raw_data((std::istreambuf_iterator<char>(stream)),
                                std::istreambuf_iterator<char>());
  std::vector<float>* orig_image = new std::vector<float>();
  orig_image->reserve(raw_data.size());
  for (int i = 0; i < raw_data.size(); ++i) {
    orig_image->push_back(static_cast<float>(raw_data[i]));
  }
  image_data->data.reset(orig_image);
}

// Loads the jpeg image.
inline void LoadImageJpeg(std::string* filename, ImageData* image_data) {
  // Reads the image file.
  std::ifstream t(*filename, std::ios::binary);
  std::string image_str((std::istreambuf_iterator<char>(t)),
                        std::istreambuf_iterator<char>());
  const int fsize = image_str.size();
  auto temp = absl::bit_cast<const uint8_t*>(image_str.data());
  std::unique_ptr<uint8_t[]> original_image;
  int original_width, original_height, original_channels;
  tensorflow::jpeg::UncompressFlags flags;
  // JDCT_ISLOW is slower but gives more accurate results.
  // This isn't always obvious in unit tests, but it makes a difference during
  // accuracy testing with the ILSVRC dataset.
  flags.dct_method = JDCT_ISLOW;
  // We always require a 3-channel RGB image as the output.
  flags.components = kNumChannels;
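  // Uncompress() decodes the JPEG into a newly allocated buffer whose
  // ownership is taken over by original_image below.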
  original_image.reset(Uncompress(temp, fsize, flags, &original_width,
                                  &original_height, &original_channels,
                                  nullptr));
  // Copies the image data.
  image_data->width = original_width;
  image_data->height = original_height;
  int original_size = original_height * original_width * original_channels;
  std::vector<float>* float_image = new std::vector<float>();
  float_image->reserve(original_size);
  for (int i = 0; i < original_size; ++i) {
    float_image->push_back(static_cast<float>(original_image[i]));
  }
  image_data->data.reset(float_image);
}

// Central-cropping.
inline void Crop(ImageData* image_data, const CroppingParams& crop_params) {
  int crop_height, crop_width;
  int input_width = image_data->width;
  int input_height = image_data->height;
  if (crop_params.has_cropping_fraction()) {
    crop_height =
        static_cast<int>(round(crop_params.cropping_fraction() * input_height));
    crop_width =
        static_cast<int>(round(crop_params.cropping_fraction() * input_width));
  } else if (crop_params.has_target_size()) {
    crop_height = crop_params.target_size().height();
    crop_width = crop_params.target_size().width();
  }
  if (crop_params.has_cropping_fraction() && crop_params.square_cropping()) {
    crop_height = std::min(crop_height, crop_width);
    crop_width = crop_height;
  }
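  // Computes the top-left corner of the crop window so that it is centered in
  // the input image.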
  int start_w = static_cast<int>(round((input_width - crop_width) / 2.0));
  int start_h = static_cast<int>(round((input_height - crop_height) / 2.0));
  std::vector<float>* cropped_image = new std::vector<float>();
  cropped_image->reserve(crop_height * crop_width * kNumChannels);
  for (int in_h = start_h; in_h < start_h + crop_height; ++in_h) {
    for (int in_w = start_w; in_w < start_w + crop_width; ++in_w) {
      for (int c = 0; c < kNumChannels; ++c) {
        cropped_image->push_back(image_data->GetData(in_h, in_w, c));
      }
    }
  }
  image_data->height = crop_height;
  image_data->width = crop_width;
  image_data->data.reset(cropped_image);
}

// Performs bilinear interpolation for a 3-channel RGB image.
// See: https://en.wikipedia.org/wiki/Bilinear_interpolation
inline void ResizeBilinear(ImageData* image_data,
                           const ResizingParams& params) {
  tflite::ResizeBilinearParams resize_params;
  resize_params.align_corners = false;
  // TODO(b/143292772): Set this to true for more accurate behavior?
  resize_params.half_pixel_centers = false;
  tflite::RuntimeShape input_shape({1, static_cast<int>(image_data->height),
                                    static_cast<int>(image_data->width),
                                    kNumChannels});
  // Calculates the output size.
  int output_height, output_width;
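  // When preserving the aspect ratio, both dimensions are scaled by the larger
  // of the two target/input ratios, so the resized image covers the target
  // size in both dimensions (the excess is typically cropped in a later step).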
  if (params.aspect_preserving()) {
    float ratio_w =
        params.target_size().width() / static_cast<float>(image_data->width);
    float ratio_h =
        params.target_size().height() / static_cast<float>(image_data->height);
    if (ratio_w >= ratio_h) {
      output_width = params.target_size().width();
      output_height = static_cast<int>(round(image_data->height * ratio_w));
    } else {
      output_width = static_cast<int>(round(image_data->width * ratio_h));
      output_height = params.target_size().height();
    }
  } else {
    output_height = params.target_size().height();
    output_width = params.target_size().width();
  }
  tflite::RuntimeShape output_size_dims({1, 1, 1, 2});
  std::vector<int32_t> output_size_data = {output_height, output_width};
  tflite::RuntimeShape output_shape(
      {1, output_height, output_width, kNumChannels});
  int output_size = output_width * output_height * kNumChannels;
  std::vector<float>* output_data = new std::vector<float>(output_size, 0);
  tflite::reference_ops::ResizeBilinear(
      resize_params, input_shape, image_data->data->data(), output_size_dims,
      output_size_data.data(), output_shape, output_data->data());
  image_data->height = output_height;
  image_data->width = output_width;
  image_data->data.reset(output_data);
}

// Pads the image to a pre-defined size.
inline void Pad(ImageData* image_data, const PaddingParams& params) {
  int output_width = params.target_size().width();
  int output_height = params.target_size().height();
  int pad_value = params.padding_value();
  tflite::PadParams pad_params;
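  // Padding is specified over the NHWC dimensions; only H (index 1) and W
  // (index 2) receive non-zero padding, split as evenly as possible between
  // the two sides.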
  pad_params.left_padding_count = 4;
  std::uninitialized_fill_n(pad_params.left_padding, 4, 0);
  pad_params.left_padding[1] =
      static_cast<int>(round((output_height - image_data->height) / 2.0));
  pad_params.left_padding[2] =
      static_cast<int>(round((output_width - image_data->width) / 2.0));
  pad_params.right_padding_count = 4;
  std::uninitialized_fill_n(pad_params.right_padding, 4, 0);
  pad_params.right_padding[1] =
      output_height - pad_params.left_padding[1] - image_data->height;
  pad_params.right_padding[2] =
      output_width - pad_params.left_padding[2] - image_data->width;
  tflite::RuntimeShape input_shape({1, static_cast<int>(image_data->height),
                                    static_cast<int>(image_data->width),
                                    kNumChannels});
  tflite::RuntimeShape output_shape(
      {1, output_height, output_width, kNumChannels});
  int output_size = output_width * output_height * kNumChannels;
  std::vector<float>* output_data = new std::vector<float>(output_size, 0);
  tflite::reference_ops::Pad(pad_params, input_shape, image_data->data->data(),
                             &pad_value, output_shape, output_data->data());
  image_data->height = output_height;
  image_data->width = output_width;
  image_data->data.reset(output_data);
}

// Normalizes the image data to a specific range with mean and scale.
inline void Normalize(ImageData* image_data,
                      const NormalizationParams& params) {
  float scale = params.scale();
  float* data_end = image_data->data->data() + image_data->data->size();
  if (params.has_channelwise_mean()) {
    float mean = params.channelwise_mean();
    for (float* data = image_data->data->data(); data < data_end; ++data) {
      *data = (*data - mean) * scale;
    }
  } else {
    float r_mean = params.means().r_mean();
    float g_mean = params.means().g_mean();
    float b_mean = params.means().b_mean();
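    // Pixel data is interleaved RGB, so each loop iteration below consumes one
    // pixel (three consecutive float values).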
    for (float* data = image_data->data->data(); data < data_end;) {
      *data = (*data - r_mean) * scale;
      ++data;
      *data = (*data - g_mean) * scale;
      ++data;
      *data = (*data - b_mean) * scale;
      ++data;
    }
  }
}
}  // namespace

TfLiteStatus ImagePreprocessingStage::Init() {
  if (!config_.has_specification() ||
      !config_.specification().has_image_preprocessing_params()) {
    LOG(ERROR) << "No preprocessing params";
    return kTfLiteError;
  }
  const ImagePreprocessingParams& params =
      config_.specification().image_preprocessing_params();
  // Validates the cropping fraction.
  for (const ImagePreprocessingStepParams& param : params.steps()) {
    if (param.has_cropping_params()) {
      const CroppingParams& crop_params = param.cropping_params();
      if (crop_params.has_cropping_fraction() &&
          (crop_params.cropping_fraction() <= 0 ||
           crop_params.cropping_fraction() > 1.0)) {
        LOG(ERROR) << "Invalid cropping fraction";
        return kTfLiteError;
      }
    }
  }
  output_type_ = static_cast<TfLiteType>(params.output_type());
  return kTfLiteOk;
}

TfLiteStatus ImagePreprocessingStage::Run() {
  if (!image_path_) {
    LOG(ERROR) << "Image path not set";
    return kTfLiteError;
  }

  ImageData image_data;
  const ImagePreprocessingParams& params =
      config_.specification().image_preprocessing_params();
  int64_t start_us = profiling::time::NowMicros();
  // Loads the image from file.
  string image_ext = image_path_->substr(image_path_->find_last_of("."));
  absl::AsciiStrToLower(&image_ext);
  bool is_raw_image = (image_ext == ".rgb8");
  if (is_raw_image) {
    LoadImageRaw(image_path_, &image_data);
  } else if (image_ext == ".jpg" || image_ext == ".jpeg") {
    LoadImageJpeg(image_path_, &image_data);
  } else {
    LOG(ERROR) << "Extension " << image_ext << " is not supported";
    return kTfLiteError;
  }

  // Cropping, padding and resizing are not supported with raw images, since
  // raw images do not contain image size information. Those steps are assumed
  // to have been applied before the raw images were generated.
  for (const ImagePreprocessingStepParams& param : params.steps()) {
    if (param.has_cropping_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image cropping will not be performed on raw images";
        continue;
      }
      Crop(&image_data, param.cropping_params());
    } else if (param.has_resizing_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image resizing will not be performed on raw images";
        continue;
      }
      ResizeBilinear(&image_data, param.resizing_params());
    } else if (param.has_padding_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image padding will not be performed on raw images";
        continue;
      }
      Pad(&image_data, param.padding_params());
    } else if (param.has_normalization_params()) {
      Normalize(&image_data, param.normalization_params());
    }
  }

  // Converts data to output type.
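  // Values are narrowed with static_cast, so they are expected to already lie
  // within the output type's range (e.g. after normalization).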
  if (output_type_ == kTfLiteUInt8) {
    uint8_preprocessed_image_.clear();
    uint8_preprocessed_image_.reserve(image_data.data->size());
    for (int i = 0; i < image_data.data->size(); ++i) {
      uint8_preprocessed_image_.push_back(
          static_cast<uint8_t>(image_data.data->at(i)));
    }
  } else if (output_type_ == kTfLiteInt8) {
    int8_preprocessed_image_.clear();
    int8_preprocessed_image_.reserve(image_data.data->size());
    for (int i = 0; i < image_data.data->size(); ++i) {
      int8_preprocessed_image_.push_back(
          static_cast<int8_t>(image_data.data->at(i)));
    }
  } else if (output_type_ == kTfLiteFloat32) {
    float_preprocessed_image_ = *image_data.data;
  }

  latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
  return kTfLiteOk;
}

void* ImagePreprocessingStage::GetPreprocessedImageData() {
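  // No preprocessed data is available until Run() has completed successfully
  // at least once.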
  if (latency_stats_.count() == 0) return nullptr;

  if (output_type_ == kTfLiteUInt8) {
    return uint8_preprocessed_image_.data();
  } else if (output_type_ == kTfLiteInt8) {
    return int8_preprocessed_image_.data();
  } else if (output_type_ == kTfLiteFloat32) {
    return float_preprocessed_image_.data();
  }
  return nullptr;
}

EvaluationStageMetrics ImagePreprocessingStage::LatestMetrics() {
  EvaluationStageMetrics metrics;
  auto* latency_metrics =
      metrics.mutable_process_metrics()->mutable_total_latency();
  latency_metrics->set_last_us(latency_stats_.newest());
  latency_metrics->set_max_us(latency_stats_.max());
  latency_metrics->set_min_us(latency_stats_.min());
  latency_metrics->set_sum_us(latency_stats_.sum());
  latency_metrics->set_avg_us(latency_stats_.avg());
  metrics.set_num_runs(static_cast<int>(latency_stats_.count()));
  return metrics;
}

}  // namespace evaluation
}  // namespace tflite