/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <memory>
#include <streambuf>
#include <string>
#include <vector>
#include "absl/base/casts.h"
#include "absl/strings/ascii.h"
#include "tensorflow/core/lib/jpeg/jpeg_handle.h"
#include "tensorflow/core/lib/jpeg/jpeg_mem.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/profiling/time.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
#include "tensorflow/lite/tools/evaluation/proto/preprocessing_steps.pb.h"

namespace tflite {
namespace evaluation {
namespace {

// We assume 3-channel RGB images.
const int kNumChannels = 3;

// Returns the offset of an element in the flattened image array, given the
// image height/width and the (h, w, c) coordinates of a pixel.
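// For example, in a 224x224 image, the green channel (c = 1) of the pixel at
// (h = 1, w = 2) is at index (1 * 224 + 2) * 3 + 1 = 679.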
inline int ImageArrayOffset(int height, int width, int h, int w, int c) {
  return (h * width + w) * kNumChannels + c;
}

// Stores data and size information of an image.
struct ImageData {
  uint32_t width;
  uint32_t height;
  std::unique_ptr<std::vector<float>> data;

  // GetData performs no checks.
  float GetData(int h, int w, int c) {
    return data->at(ImageArrayOffset(height, width, h, w, c));
  }
};

// Loads the raw image.
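// Raw (.rgb8) files are read as a flat stream of interleaved 8-bit RGB
// samples. Note that width/height are left unset here, which is why the
// size-dependent steps (crop/resize/pad) are skipped for raw images in Run().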
inline void LoadImageRaw(std::string* filename, ImageData* image_data) {
  std::ifstream stream(filename->c_str(), std::ios::in | std::ios::binary);
  std::vector<uint8_t> raw_data((std::istreambuf_iterator<char>(stream)),
                                std::istreambuf_iterator<char>());
  std::vector<float>* orig_image = new std::vector<float>();
  orig_image->reserve(raw_data.size());
  for (int i = 0; i < raw_data.size(); ++i) {
    orig_image->push_back(static_cast<float>(raw_data[i]));
  }
  image_data->data.reset(orig_image);
}

// Loads the jpeg image.
inline void LoadImageJpeg(std::string* filename, ImageData* image_data) {
  // Reads image.
  std::ifstream t(*filename);
  std::string image_str((std::istreambuf_iterator<char>(t)),
                        std::istreambuf_iterator<char>());
  const int fsize = image_str.size();
  auto temp = absl::bit_cast<const uint8_t*>(image_str.data());
  std::unique_ptr<uint8_t[]> original_image;
  int original_width, original_height, original_channels;
  tensorflow::jpeg::UncompressFlags flags;
  // JDCT_ISLOW performs slower but more accurate pre-processing.
  // This isn't always obvious in unit tests, but makes a difference during
  // accuracy testing with the ILSVRC dataset.
  flags.dct_method = JDCT_ISLOW;
  // We require a 3-channel (RGB) image as the output.
  flags.components = kNumChannels;
  original_image.reset(Uncompress(temp, fsize, flags, &original_width,
                                  &original_height, &original_channels,
                                  nullptr));
  // Copies the image data.
  image_data->width = original_width;
  image_data->height = original_height;
  int original_size = original_height * original_width * original_channels;
  std::vector<float>* float_image = new std::vector<float>();
  float_image->reserve(original_size);
  for (int i = 0; i < original_size; ++i) {
    float_image->push_back(static_cast<float>(original_image[i]));
  }
  image_data->data.reset(float_image);
}

// Central-cropping.
inline void Crop(ImageData* image_data, const CroppingParams& crop_params) {
  int crop_height, crop_width;
  int input_width = image_data->width;
  int input_height = image_data->height;
  if (crop_params.has_cropping_fraction()) {
    crop_height =
        static_cast<int>(round(crop_params.cropping_fraction() * input_height));
    crop_width =
        static_cast<int>(round(crop_params.cropping_fraction() * input_width));
  } else if (crop_params.has_target_size()) {
    crop_height = crop_params.target_size().height();
    crop_width = crop_params.target_size().width();
  }
  if (crop_params.has_cropping_fraction() && crop_params.square_cropping()) {
    crop_height = std::min(crop_height, crop_width);
    crop_width = crop_height;
  }
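  // Center the crop window: any leftover rows/columns are split as evenly as
  // possible between the two sides of the image.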
  int start_w = static_cast<int>(round((input_width - crop_width) / 2.0));
  int start_h = static_cast<int>(round((input_height - crop_height) / 2.0));
  std::vector<float>* cropped_image = new std::vector<float>();
  cropped_image->reserve(crop_height * crop_width * kNumChannels);
  for (int in_h = start_h; in_h < start_h + crop_height; ++in_h) {
    for (int in_w = start_w; in_w < start_w + crop_width; ++in_w) {
      for (int c = 0; c < kNumChannels; ++c) {
        cropped_image->push_back(image_data->GetData(in_h, in_w, c));
      }
    }
  }
  image_data->height = crop_height;
  image_data->width = crop_width;
  image_data->data.reset(cropped_image);
}

// Performs bilinear interpolation for a 3-channel RGB image.
// See: https://en.wikipedia.org/wiki/Bilinear_interpolation
inline void ResizeBilinear(ImageData* image_data,
                           const ResizingParams& params) {
  tflite::ResizeBilinearParams resize_params;
  resize_params.align_corners = false;
  // TODO(b/143292772): Set this to true for more accurate behavior?
  resize_params.half_pixel_centers = false;
  tflite::RuntimeShape input_shape({1, static_cast<int>(image_data->height),
                                    static_cast<int>(image_data->width),
                                    kNumChannels});
  // Calculates output size.
  int output_height, output_width;
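  // When aspect_preserving is set, scale by the larger of the width/height
  // ratios so the resized image covers the target in both dimensions: the
  // dimension with the larger ratio matches the target exactly while the
  // other ends up at least as large (typically trimmed by a later crop step).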
  if (params.aspect_preserving()) {
    float ratio_w =
        params.target_size().width() / static_cast<float>(image_data->width);
    float ratio_h =
        params.target_size().height() / static_cast<float>(image_data->height);
    if (ratio_w >= ratio_h) {
      output_width = params.target_size().width();
      output_height = static_cast<int>(round(image_data->height * ratio_w));
    } else {
      output_width = static_cast<int>(round(image_data->width * ratio_h));
      output_height = params.target_size().height();
    }
  } else {
    output_height = params.target_size().height();
    output_width = params.target_size().width();
  }
  tflite::RuntimeShape output_size_dims({1, 1, 1, 2});
  std::vector<int32_t> output_size_data = {output_height, output_width};
  tflite::RuntimeShape output_shape(
      {1, output_height, output_width, kNumChannels});
  int output_size = output_width * output_height * kNumChannels;
  std::vector<float>* output_data = new std::vector<float>(output_size, 0);
  tflite::reference_ops::ResizeBilinear(
      resize_params, input_shape, image_data->data->data(), output_size_dims,
      output_size_data.data(), output_shape, output_data->data());
  image_data->height = output_height;
  image_data->width = output_width;
  image_data->data.reset(output_data);
}

// Pads the image to a pre-defined size.
inline void Pad(ImageData* image_data, const PaddingParams& params) {
  int output_width = params.target_size().width();
  int output_height = params.target_size().height();
  int pad_value = params.padding_value();
  tflite::PadParams pad_params;
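  // Padding counts are indexed per NHWC dimension (0 = batch, 1 = height,
  // 2 = width, 3 = channels); only height and width receive non-zero padding,
  // split as evenly as possible between the two sides.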
  pad_params.left_padding_count = 4;
  std::uninitialized_fill_n(pad_params.left_padding, 4, 0);
  pad_params.left_padding[1] =
      static_cast<int>(round((output_height - image_data->height) / 2.0));
  pad_params.left_padding[2] =
      static_cast<int>(round((output_width - image_data->width) / 2.0));
  pad_params.right_padding_count = 4;
  std::uninitialized_fill_n(pad_params.right_padding, 4, 0);
  pad_params.right_padding[1] =
      output_height - pad_params.left_padding[1] - image_data->height;
  pad_params.right_padding[2] =
      output_width - pad_params.left_padding[2] - image_data->width;
  tflite::RuntimeShape input_shape({1, static_cast<int>(image_data->height),
                                    static_cast<int>(image_data->width),
                                    kNumChannels});
  tflite::RuntimeShape output_shape(
      {1, output_height, output_width, kNumChannels});
  int output_size = output_width * output_height * kNumChannels;
  std::vector<float>* output_data = new std::vector<float>(output_size, 0);
  tflite::reference_ops::Pad(pad_params, input_shape, image_data->data->data(),
                             &pad_value, output_shape, output_data->data());
  image_data->height = output_height;
  image_data->width = output_width;
  image_data->data.reset(output_data);
}

// Normalizes the image data to a specific range with mean and scale.
inline void Normalize(ImageData* image_data,
                      const NormalizationParams& params) {
  float scale = params.scale();
  float* data_end = image_data->data->data() + image_data->data->size();
  if (params.has_channelwise_mean()) {
    float mean = params.channelwise_mean();
    for (float* data = image_data->data->data(); data < data_end; ++data) {
      *data = (*data - mean) * scale;
    }
  } else {
    float r_mean = params.means().r_mean();
    float g_mean = params.means().g_mean();
    float b_mean = params.means().b_mean();
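    // Pixels are stored as interleaved R, G, B triplets (see
    // ImageArrayOffset), so the three means are applied cyclically.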
    for (float* data = image_data->data->data(); data < data_end;) {
      *data = (*data - r_mean) * scale;
      ++data;
      *data = (*data - g_mean) * scale;
      ++data;
      *data = (*data - b_mean) * scale;
      ++data;
    }
  }
}
}  // namespace

TfLiteStatus ImagePreprocessingStage::Init() {
  if (!config_.has_specification() ||
      !config_.specification().has_image_preprocessing_params()) {
    LOG(ERROR) << "No preprocessing params";
    return kTfLiteError;
  }
  const ImagePreprocessingParams& params =
      config_.specification().image_preprocessing_params();
  // Validates the cropping fraction.
  for (const ImagePreprocessingStepParams& param : params.steps()) {
    if (param.has_cropping_params()) {
      const CroppingParams& crop_params = param.cropping_params();
      if (crop_params.has_cropping_fraction() &&
          (crop_params.cropping_fraction() <= 0 ||
           crop_params.cropping_fraction() > 1.0)) {
        LOG(ERROR) << "Invalid cropping fraction";
        return kTfLiteError;
      }
    }
  }
  output_type_ = static_cast<TfLiteType>(params.output_type());
  return kTfLiteOk;
}

TfLiteStatus ImagePreprocessingStage::Run() {
  if (!image_path_) {
    LOG(ERROR) << "Image path not set";
    return kTfLiteError;
  }

  ImageData image_data;
  const ImagePreprocessingParams& params =
      config_.specification().image_preprocessing_params();
  int64_t start_us = profiling::time::NowMicros();
  // Loads the image from file.
  std::string image_ext = image_path_->substr(image_path_->find_last_of("."));
  absl::AsciiStrToLower(&image_ext);
  bool is_raw_image = (image_ext == ".rgb8");
  if (is_raw_image) {
    LoadImageRaw(image_path_, &image_data);
  } else if (image_ext == ".jpg" || image_ext == ".jpeg") {
    LoadImageJpeg(image_path_, &image_data);
  } else {
    LOG(ERROR) << "Extension " << image_ext << " is not supported";
    return kTfLiteError;
  }

  // Cropping, padding and resizing are not supported with raw images since raw
  // images do not contain image size information. Those steps are assumed to
  // be done before raw images are generated.
  for (const ImagePreprocessingStepParams& param : params.steps()) {
    if (param.has_cropping_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image cropping will not be performed on raw images";
        continue;
      }
      Crop(&image_data, param.cropping_params());
    } else if (param.has_resizing_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image resizing will not be performed on raw images";
        continue;
      }
      ResizeBilinear(&image_data, param.resizing_params());
    } else if (param.has_padding_params()) {
      if (is_raw_image) {
        LOG(WARNING) << "Image padding will not be performed on raw images";
        continue;
      }
      Pad(&image_data, param.padding_params());
    } else if (param.has_normalization_params()) {
      Normalize(&image_data, param.normalization_params());
    }
  }

  // Converts data to output type.
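  // Note that the casts below simply truncate the float values; no
  // quantization parameters (scale/zero-point) are applied at this stage.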
  if (output_type_ == kTfLiteUInt8) {
    uint8_preprocessed_image_.clear();
    uint8_preprocessed_image_.reserve(image_data.data->size());
    for (int i = 0; i < image_data.data->size(); ++i) {
      uint8_preprocessed_image_.push_back(
          static_cast<uint8_t>(image_data.data->at(i)));
    }
  } else if (output_type_ == kTfLiteInt8) {
    int8_preprocessed_image_.clear();
    int8_preprocessed_image_.reserve(image_data.data->size());
    for (int i = 0; i < image_data.data->size(); ++i) {
      int8_preprocessed_image_.push_back(
          static_cast<int8_t>(image_data.data->at(i)));
    }
  } else if (output_type_ == kTfLiteFloat32) {
    float_preprocessed_image_ = *image_data.data;
  }

  latency_stats_.UpdateStat(profiling::time::NowMicros() - start_us);
  return kTfLiteOk;
}

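// Returns a pointer to the most recent preprocessed image, or nullptr if
// Run() has not yet produced one (or the output type is unsupported).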
void* ImagePreprocessingStage::GetPreprocessedImageData() {
  if (latency_stats_.count() == 0) return nullptr;

  if (output_type_ == kTfLiteUInt8) {
    return uint8_preprocessed_image_.data();
  } else if (output_type_ == kTfLiteInt8) {
    return int8_preprocessed_image_.data();
  } else if (output_type_ == kTfLiteFloat32) {
    return float_preprocessed_image_.data();
  }
  return nullptr;
}

EvaluationStageMetrics ImagePreprocessingStage::LatestMetrics() {
  EvaluationStageMetrics metrics;
  auto* latency_metrics =
      metrics.mutable_process_metrics()->mutable_total_latency();
  latency_metrics->set_last_us(latency_stats_.newest());
  latency_metrics->set_max_us(latency_stats_.max());
  latency_metrics->set_min_us(latency_stats_.min());
  latency_metrics->set_sum_us(latency_stats_.sum());
  latency_metrics->set_avg_us(latency_stats_.avg());
  metrics.set_num_runs(static_cast<int>(latency_stats_.count()));
  return metrics;
}

}  // namespace evaluation
}  // namespace tflite