xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/include/ObjectDetectionPipeline.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnNetworkExecutor.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "YoloResultDecoder.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "SSDResultDecoder.hpp"
11*89c4ff92SAndroid Build Coastguard Worker # include "ImageUtils.hpp"
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <opencv2/opencv.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace od
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker /**
18*89c4ff92SAndroid Build Coastguard Worker  * Generic object detection pipeline with 3 steps: data pre-processing, inference execution and inference
19*89c4ff92SAndroid Build Coastguard Worker  * result post-processing.
20*89c4ff92SAndroid Build Coastguard Worker  *
21*89c4ff92SAndroid Build Coastguard Worker  */
22*89c4ff92SAndroid Build Coastguard Worker class ObjDetectionPipeline {
23*89c4ff92SAndroid Build Coastguard Worker public:
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker     /**
26*89c4ff92SAndroid Build Coastguard Worker      * Creates object detection pipeline with given network executor and decoder.
27*89c4ff92SAndroid Build Coastguard Worker      * @param executor - unique pointer to inference runner
28*89c4ff92SAndroid Build Coastguard Worker      * @param decoder - unique pointer to inference results decoder
29*89c4ff92SAndroid Build Coastguard Worker      */
30*89c4ff92SAndroid Build Coastguard Worker     ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
31*89c4ff92SAndroid Build Coastguard Worker                          std::unique_ptr<IDetectionResultDecoder> decoder);
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     /**
34*89c4ff92SAndroid Build Coastguard Worker      * @brief Standard image pre-processing implementation.
35*89c4ff92SAndroid Build Coastguard Worker      *
36*89c4ff92SAndroid Build Coastguard Worker      * Re-sizes an image keeping aspect ratio, pads if necessary to fit the network input layer dimensions.
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker      * @param[in] frame - input image, expected data type is uint8.
39*89c4ff92SAndroid Build Coastguard Worker      * @param[out] processed - output image, data type is preserved.
40*89c4ff92SAndroid Build Coastguard Worker      */
41*89c4ff92SAndroid Build Coastguard Worker     virtual void PreProcessing(const cv::Mat& frame, cv::Mat& processed);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker     /**
44*89c4ff92SAndroid Build Coastguard Worker      * @brief Executes inference
45*89c4ff92SAndroid Build Coastguard Worker      *
46*89c4ff92SAndroid Build Coastguard Worker      * Calls inference runner provided during instance construction.
47*89c4ff92SAndroid Build Coastguard Worker      *
48*89c4ff92SAndroid Build Coastguard Worker      * @param[in] processed - input inference data. Data type should be aligned with input tensor.
49*89c4ff92SAndroid Build Coastguard Worker      * @param[out] result - raw floating point inference results.
50*89c4ff92SAndroid Build Coastguard Worker      */
51*89c4ff92SAndroid Build Coastguard Worker     virtual void Inference(const cv::Mat& processed, common::InferenceResults<float>& result);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     /**
54*89c4ff92SAndroid Build Coastguard Worker      * @brief Standard inference results post-processing implementation.
55*89c4ff92SAndroid Build Coastguard Worker      *
56*89c4ff92SAndroid Build Coastguard Worker      * Decodes inference results using decoder provided during construction.
57*89c4ff92SAndroid Build Coastguard Worker      *
58*89c4ff92SAndroid Build Coastguard Worker      * @param[in] inferenceResult - inference results to be decoded.
59*89c4ff92SAndroid Build Coastguard Worker      * @param[in] callback - a function to be called after successful inference results decoding.
60*89c4ff92SAndroid Build Coastguard Worker      */
61*89c4ff92SAndroid Build Coastguard Worker     virtual void PostProcessing(common::InferenceResults<float>& inferenceResult,
62*89c4ff92SAndroid Build Coastguard Worker                                 const std::function<void (DetectedObjects)>& callback);
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker protected:
65*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<common::ArmnnNetworkExecutor<float>> m_executor;
66*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IDetectionResultDecoder> m_decoder;
67*89c4ff92SAndroid Build Coastguard Worker     common::Size m_inputImageSize{};
68*89c4ff92SAndroid Build Coastguard Worker     cv::Mat m_processedFrame;
69*89c4ff92SAndroid Build Coastguard Worker };
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker /**
72*89c4ff92SAndroid Build Coastguard Worker  * Specific to Yolo v3 tiny object detection pipeline implementation.
73*89c4ff92SAndroid Build Coastguard Worker  */
74*89c4ff92SAndroid Build Coastguard Worker class YoloV3Tiny: public ObjDetectionPipeline{
75*89c4ff92SAndroid Build Coastguard Worker public:
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     /**
78*89c4ff92SAndroid Build Coastguard Worker      * Constructs object detection pipeline for Yolo v3 tiny network.
79*89c4ff92SAndroid Build Coastguard Worker      *
80*89c4ff92SAndroid Build Coastguard Worker      * Network input is expected to be uint8 or fp32. Data range [0, 255].
81*89c4ff92SAndroid Build Coastguard Worker      * Network output is FP32.
82*89c4ff92SAndroid Build Coastguard Worker      *
83*89c4ff92SAndroid Build Coastguard Worker      * @param executor[in] - unique pointer to inference runner
84*89c4ff92SAndroid Build Coastguard Worker      * @param NMSThreshold[in] - non max suppression threshold for decoding step
85*89c4ff92SAndroid Build Coastguard Worker      * @param ClsThreshold[in] -  class probability threshold for decoding step
86*89c4ff92SAndroid Build Coastguard Worker      * @param ObjectThreshold[in] - detected object score threshold for decoding step
87*89c4ff92SAndroid Build Coastguard Worker      */
88*89c4ff92SAndroid Build Coastguard Worker     YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
89*89c4ff92SAndroid Build Coastguard Worker                float NMSThreshold, float ClsThreshold, float ObjectThreshold);
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     /**
92*89c4ff92SAndroid Build Coastguard Worker      * @brief Yolo v3 tiny image pre-processing implementation.
93*89c4ff92SAndroid Build Coastguard Worker      *
94*89c4ff92SAndroid Build Coastguard Worker      * On top of the standard pre-processing, converts input data type according to the network input tensor data type.
95*89c4ff92SAndroid Build Coastguard Worker      * Supported data types: uint8 and float32.
96*89c4ff92SAndroid Build Coastguard Worker      *
97*89c4ff92SAndroid Build Coastguard Worker      * @param[in] original - input image data
98*89c4ff92SAndroid Build Coastguard Worker      * @param[out] processed - image data ready to be used for inference.
99*89c4ff92SAndroid Build Coastguard Worker      */
100*89c4ff92SAndroid Build Coastguard Worker     void PreProcessing(const cv::Mat& original, cv::Mat& processed);
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker };
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker /**
105*89c4ff92SAndroid Build Coastguard Worker  * Specific to MobileNet SSD v1 object detection pipeline implementation.
106*89c4ff92SAndroid Build Coastguard Worker  */
107*89c4ff92SAndroid Build Coastguard Worker class MobileNetSSDv1: public ObjDetectionPipeline {
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker public:
110*89c4ff92SAndroid Build Coastguard Worker     /**
111*89c4ff92SAndroid Build Coastguard Worker      * Constructs object detection pipeline for MobileNet SSD network.
112*89c4ff92SAndroid Build Coastguard Worker      *
113*89c4ff92SAndroid Build Coastguard Worker      * Network input is expected to be uint8 or fp32. Data range [-1, 1].
114*89c4ff92SAndroid Build Coastguard Worker      * Network output is FP32.
115*89c4ff92SAndroid Build Coastguard Worker      *
116*89c4ff92SAndroid Build Coastguard Worker      * @param[in] - unique pointer to inference runner
117*89c4ff92SAndroid Build Coastguard Worker      * @paramp[in] objectThreshold - detected object score threshold for decoding step
118*89c4ff92SAndroid Build Coastguard Worker      */
119*89c4ff92SAndroid Build Coastguard Worker     MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
120*89c4ff92SAndroid Build Coastguard Worker                    float objectThreshold);
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     /**
123*89c4ff92SAndroid Build Coastguard Worker      * @brief MobileNet SSD image pre-processing implementation.
124*89c4ff92SAndroid Build Coastguard Worker      *
125*89c4ff92SAndroid Build Coastguard Worker      * On top of the standard pre-processing, converts input data type according to the network input tensor data type
126*89c4ff92SAndroid Build Coastguard Worker      * and scales input data from [0, 255] to [-1, 1] for FP32 input.
127*89c4ff92SAndroid Build Coastguard Worker      *
128*89c4ff92SAndroid Build Coastguard Worker      * Supported input data types: uint8 and float32.
129*89c4ff92SAndroid Build Coastguard Worker      *
130*89c4ff92SAndroid Build Coastguard Worker      * @param[in] original - input image data
131*89c4ff92SAndroid Build Coastguard Worker      * @param processed[out] - image data ready to be used for inference.
132*89c4ff92SAndroid Build Coastguard Worker      */
133*89c4ff92SAndroid Build Coastguard Worker     void PreProcessing(const cv::Mat& original, cv::Mat& processed);
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker };
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker using IPipelinePtr = std::unique_ptr<od::ObjDetectionPipeline>;
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker /**
140*89c4ff92SAndroid Build Coastguard Worker  * Constructs object detection pipeline based on configuration provided.
141*89c4ff92SAndroid Build Coastguard Worker  *
142*89c4ff92SAndroid Build Coastguard Worker  * @param[in] config - object detection pipeline configuration.
143*89c4ff92SAndroid Build Coastguard Worker  *
144*89c4ff92SAndroid Build Coastguard Worker  * @return unique pointer to object detection pipeline.
145*89c4ff92SAndroid Build Coastguard Worker  */
146*89c4ff92SAndroid Build Coastguard Worker IPipelinePtr CreatePipeline(common::PipelineOptions& config);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker }// namespace od