xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/evaluation/stages/image_preprocessing_stage.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_PREPROCESSING_STAGE_H_
16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_PREPROCESSING_STAGE_H_
17 
18 #include <stdint.h>
19 
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/util/stats_calculator.h"
26 #include "tensorflow/lite/tools/evaluation/evaluation_stage.h"
27 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
28 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
29 #include "tensorflow/lite/tools/evaluation/proto/preprocessing_steps.pb.h"
30 
31 namespace tflite {
32 namespace evaluation {
33 
34 // EvaluationStage to read contents of an image and preprocess it for inference.
35 // Currently only supports JPEGs.
36 class ImagePreprocessingStage : public EvaluationStage {
37  public:
ImagePreprocessingStage(const EvaluationStageConfig & config)38   explicit ImagePreprocessingStage(const EvaluationStageConfig& config)
39       : EvaluationStage(config) {}
40 
41   TfLiteStatus Init() override;
42 
43   TfLiteStatus Run() override;
44 
45   EvaluationStageMetrics LatestMetrics() override;
46 
~ImagePreprocessingStage()47   ~ImagePreprocessingStage() override {}
48 
49   // Call before Run().
SetImagePath(std::string * image_path)50   void SetImagePath(std::string* image_path) { image_path_ = image_path; }
51 
52   // Provides preprocessing output.
53   void* GetPreprocessedImageData();
54 
55  private:
56   std::string* image_path_ = nullptr;
57   TfLiteType output_type_;
58   tensorflow::Stat<int64_t> latency_stats_;
59 
60   // One of the following 3 vectors will be populated based on output_type_.
61   std::vector<float> float_preprocessed_image_;
62   std::vector<int8_t> int8_preprocessed_image_;
63   std::vector<uint8_t> uint8_preprocessed_image_;
64 };
65 
66 // Helper class to build a new ImagePreprocessingParams.
67 class ImagePreprocessingConfigBuilder {
68  public:
ImagePreprocessingConfigBuilder(const std::string & name,TfLiteType output_type)69   ImagePreprocessingConfigBuilder(const std::string& name,
70                                   TfLiteType output_type) {
71     config_.set_name(name);
72     config_.mutable_specification()
73         ->mutable_image_preprocessing_params()
74         ->set_output_type(static_cast<int>(output_type));
75   }
76 
77   // Adds a cropping step with cropping fraction.
78   void AddCroppingStep(float cropping_fraction,
79                        bool use_square_cropping = false) {
80     ImagePreprocessingStepParams params;
81     params.mutable_cropping_params()->set_cropping_fraction(cropping_fraction);
82     params.mutable_cropping_params()->set_square_cropping(use_square_cropping);
83     config_.mutable_specification()
84         ->mutable_image_preprocessing_params()
85         ->mutable_steps()
86         ->Add(std::move(params));
87   }
88 
89   // Adds a cropping step with target size.
90   void AddCroppingStep(uint32_t width, uint32_t height,
91                        bool use_square_cropping = false) {
92     ImagePreprocessingStepParams params;
93     params.mutable_cropping_params()->mutable_target_size()->set_height(height);
94     params.mutable_cropping_params()->mutable_target_size()->set_width(width);
95     params.mutable_cropping_params()->set_square_cropping(use_square_cropping);
96     config_.mutable_specification()
97         ->mutable_image_preprocessing_params()
98         ->mutable_steps()
99         ->Add(std::move(params));
100   }
101 
102   // Adds a resizing step.
AddResizingStep(uint32_t width,uint32_t height,bool aspect_preserving)103   void AddResizingStep(uint32_t width, uint32_t height,
104                        bool aspect_preserving) {
105     ImagePreprocessingStepParams params;
106     params.mutable_resizing_params()->set_aspect_preserving(aspect_preserving);
107     params.mutable_resizing_params()->mutable_target_size()->set_height(height);
108     params.mutable_resizing_params()->mutable_target_size()->set_width(width);
109     config_.mutable_specification()
110         ->mutable_image_preprocessing_params()
111         ->mutable_steps()
112         ->Add(std::move(params));
113   }
114 
115   // Adds a padding step.
AddPaddingStep(uint32_t width,uint32_t height,int value)116   void AddPaddingStep(uint32_t width, uint32_t height, int value) {
117     ImagePreprocessingStepParams params;
118     params.mutable_padding_params()->mutable_target_size()->set_height(height);
119     params.mutable_padding_params()->mutable_target_size()->set_width(width);
120     params.mutable_padding_params()->set_padding_value(value);
121     config_.mutable_specification()
122         ->mutable_image_preprocessing_params()
123         ->mutable_steps()
124         ->Add(std::move(params));
125   }
126 
127   // Adds a square padding step.
AddSquarePaddingStep(int value)128   void AddSquarePaddingStep(int value) {
129     ImagePreprocessingStepParams params;
130     params.mutable_padding_params()->set_square_padding(true);
131     params.mutable_padding_params()->set_padding_value(value);
132     config_.mutable_specification()
133         ->mutable_image_preprocessing_params()
134         ->mutable_steps()
135         ->Add(std::move(params));
136   }
137 
138   // Adds a subtracting means step.
AddPerChannelNormalizationStep(float r_mean,float g_mean,float b_mean,float scale)139   void AddPerChannelNormalizationStep(float r_mean, float g_mean, float b_mean,
140                                       float scale) {
141     ImagePreprocessingStepParams params;
142     params.mutable_normalization_params()->mutable_means()->set_r_mean(r_mean);
143     params.mutable_normalization_params()->mutable_means()->set_g_mean(g_mean);
144     params.mutable_normalization_params()->mutable_means()->set_b_mean(b_mean);
145     params.mutable_normalization_params()->set_scale(scale);
146     config_.mutable_specification()
147         ->mutable_image_preprocessing_params()
148         ->mutable_steps()
149         ->Add(std::move(params));
150   }
151 
152   // Adds a normalization step.
AddNormalizationStep(float mean,float scale)153   void AddNormalizationStep(float mean, float scale) {
154     ImagePreprocessingStepParams params;
155     params.mutable_normalization_params()->set_channelwise_mean(mean);
156     params.mutable_normalization_params()->set_scale(scale);
157     config_.mutable_specification()
158         ->mutable_image_preprocessing_params()
159         ->mutable_steps()
160         ->Add(std::move(params));
161   }
162 
163   // Adds a normalization step with default value.
AddDefaultNormalizationStep()164   void AddDefaultNormalizationStep() {
165     switch (
166         config_.specification().image_preprocessing_params().output_type()) {
167       case kTfLiteFloat32:
168         AddNormalizationStep(127.5, 1.0 / 127.5);
169         break;
170       case kTfLiteUInt8:
171         break;
172       case kTfLiteInt8:
173         AddNormalizationStep(128.0, 1.0);
174         break;
175       default:
176         LOG(ERROR) << "Type not supported";
177         break;
178     }
179   }
180 
build()181   EvaluationStageConfig build() { return std::move(config_); }
182 
183  private:
184   EvaluationStageConfig config_;
185 };
186 
187 }  // namespace evaluation
188 }  // namespace tflite
189 
190 #endif  // TENSORFLOW_LITE_TOOLS_EVALUATION_STAGES_IMAGE_PREPROCESSING_STAGE_H_
191