xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/ObjectDetectionPipeline.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ObjectDetectionPipeline.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ImageUtils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker namespace od
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker 
ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,std::unique_ptr<IDetectionResultDecoder> decoder)12*89c4ff92SAndroid Build Coastguard Worker ObjDetectionPipeline::ObjDetectionPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
13*89c4ff92SAndroid Build Coastguard Worker                                            std::unique_ptr<IDetectionResultDecoder> decoder) :
14*89c4ff92SAndroid Build Coastguard Worker     m_executor(std::move(executor)),
15*89c4ff92SAndroid Build Coastguard Worker     m_decoder(std::move(decoder)){}
16*89c4ff92SAndroid Build Coastguard Worker 
Inference(const cv::Mat & processed,common::InferenceResults<float> & result)17*89c4ff92SAndroid Build Coastguard Worker void od::ObjDetectionPipeline::Inference(const cv::Mat& processed, common::InferenceResults<float>& result)
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker     m_executor->Run(processed.data, processed.total() * processed.elemSize(), result);
20*89c4ff92SAndroid Build Coastguard Worker }
21*89c4ff92SAndroid Build Coastguard Worker 
PostProcessing(common::InferenceResults<float> & inferenceResult,const std::function<void (DetectedObjects)> & callback)22*89c4ff92SAndroid Build Coastguard Worker void ObjDetectionPipeline::PostProcessing(common::InferenceResults<float>& inferenceResult,
23*89c4ff92SAndroid Build Coastguard Worker         const std::function<void (DetectedObjects)>& callback)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     DetectedObjects detections = m_decoder->Decode(inferenceResult, m_inputImageSize,
26*89c4ff92SAndroid Build Coastguard Worker                                            m_executor->GetImageAspectRatio(), {});
27*89c4ff92SAndroid Build Coastguard Worker     if (callback)
28*89c4ff92SAndroid Build Coastguard Worker     {
29*89c4ff92SAndroid Build Coastguard Worker         callback(detections);
30*89c4ff92SAndroid Build Coastguard Worker     }
31*89c4ff92SAndroid Build Coastguard Worker }
32*89c4ff92SAndroid Build Coastguard Worker 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)33*89c4ff92SAndroid Build Coastguard Worker void ObjDetectionPipeline::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker     m_inputImageSize.m_Height = frame.rows;
36*89c4ff92SAndroid Build Coastguard Worker     m_inputImageSize.m_Width = frame.cols;
37*89c4ff92SAndroid Build Coastguard Worker     ResizeWithPad(frame, processed, m_processedFrame, m_executor->GetImageAspectRatio());
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker 
MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,float objectThreshold)40*89c4ff92SAndroid Build Coastguard Worker MobileNetSSDv1::MobileNetSSDv1(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
41*89c4ff92SAndroid Build Coastguard Worker                                float objectThreshold) :
42*89c4ff92SAndroid Build Coastguard Worker     ObjDetectionPipeline(std::move(executor),
43*89c4ff92SAndroid Build Coastguard Worker                          std::make_unique<SSDResultDecoder>(objectThreshold))
44*89c4ff92SAndroid Build Coastguard Worker {}
45*89c4ff92SAndroid Build Coastguard Worker 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)46*89c4ff92SAndroid Build Coastguard Worker void MobileNetSSDv1::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker     ObjDetectionPipeline::PreProcessing(frame, processed);
49*89c4ff92SAndroid Build Coastguard Worker     if (m_executor->GetInputDataType() == armnn::DataType::Float32)
50*89c4ff92SAndroid Build Coastguard Worker     {
51*89c4ff92SAndroid Build Coastguard Worker         // [0, 255] => [-1.0, 1.0]
52*89c4ff92SAndroid Build Coastguard Worker         processed.convertTo(processed, CV_32FC3, 1 / 127.5, -1);
53*89c4ff92SAndroid Build Coastguard Worker     }
54*89c4ff92SAndroid Build Coastguard Worker }
YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,float NMSThreshold,float ClsThreshold,float ObjectThreshold)55*89c4ff92SAndroid Build Coastguard Worker YoloV3Tiny::YoloV3Tiny(std::unique_ptr<common::ArmnnNetworkExecutor<float>> executor,
56*89c4ff92SAndroid Build Coastguard Worker                        float NMSThreshold, float ClsThreshold, float ObjectThreshold) :
57*89c4ff92SAndroid Build Coastguard Worker     ObjDetectionPipeline(std::move(executor),
58*89c4ff92SAndroid Build Coastguard Worker                          std::move(std::make_unique<YoloResultDecoder>(NMSThreshold,
59*89c4ff92SAndroid Build Coastguard Worker                                                                        ClsThreshold,
60*89c4ff92SAndroid Build Coastguard Worker                                                                        ObjectThreshold)))
61*89c4ff92SAndroid Build Coastguard Worker {}
62*89c4ff92SAndroid Build Coastguard Worker 
PreProcessing(const cv::Mat & frame,cv::Mat & processed)63*89c4ff92SAndroid Build Coastguard Worker void YoloV3Tiny::PreProcessing(const cv::Mat& frame, cv::Mat& processed)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     ObjDetectionPipeline::PreProcessing(frame, processed);
66*89c4ff92SAndroid Build Coastguard Worker     if (m_executor->GetInputDataType() == armnn::DataType::Float32)
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         processed.convertTo(processed, CV_32FC3);
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker 
CreatePipeline(common::PipelineOptions & config)72*89c4ff92SAndroid Build Coastguard Worker IPipelinePtr CreatePipeline(common::PipelineOptions& config)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker     auto executor = std::make_unique<common::ArmnnNetworkExecutor<float>>(config.m_ModelFilePath,
75*89c4ff92SAndroid Build Coastguard Worker                                                                           config.m_backends,
76*89c4ff92SAndroid Build Coastguard Worker                                                                           config.m_ProfilingEnabled);
77*89c4ff92SAndroid Build Coastguard Worker     if (config.m_ModelName == "SSD_MOBILE")
78*89c4ff92SAndroid Build Coastguard Worker     {
79*89c4ff92SAndroid Build Coastguard Worker         float detectionThreshold = 0.5;
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<od::MobileNetSSDv1>(std::move(executor),
82*89c4ff92SAndroid Build Coastguard Worker                                                     detectionThreshold
83*89c4ff92SAndroid Build Coastguard Worker         );
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker     else if (config.m_ModelName == "YOLO_V3_TINY")
86*89c4ff92SAndroid Build Coastguard Worker     {
87*89c4ff92SAndroid Build Coastguard Worker         float NMSThreshold = 0.6f;
88*89c4ff92SAndroid Build Coastguard Worker         float ClsThreshold = 0.6f;
89*89c4ff92SAndroid Build Coastguard Worker         float ObjectThreshold = 0.6f;
90*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<od::YoloV3Tiny>(std::move(executor),
91*89c4ff92SAndroid Build Coastguard Worker                                                 NMSThreshold,
92*89c4ff92SAndroid Build Coastguard Worker                                                 ClsThreshold,
93*89c4ff92SAndroid Build Coastguard Worker                                                 ObjectThreshold
94*89c4ff92SAndroid Build Coastguard Worker         );
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker     else
97*89c4ff92SAndroid Build Coastguard Worker     {
98*89c4ff92SAndroid Build Coastguard Worker         throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " supplied by user.");
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker }// namespace od
103