xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/test/PipelineTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
#include <stdexcept>
#include <string>

#include <catch.hpp>
#include <opencv2/opencv.hpp>

#include "ObjectDetectionPipeline.hpp"
#include "Types.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
GetResourceFilePath(const std::string & filename)10*89c4ff92SAndroid Build Coastguard Worker static std::string GetResourceFilePath(const std::string& filename)
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker     std::string testResources = TEST_RESOURCE_DIR;
13*89c4ff92SAndroid Build Coastguard Worker     if (0 == testResources.size())
14*89c4ff92SAndroid Build Coastguard Worker     {
15*89c4ff92SAndroid Build Coastguard Worker         throw "Invalid test resources directory provided";
16*89c4ff92SAndroid Build Coastguard Worker     }
17*89c4ff92SAndroid Build Coastguard Worker     else
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker         if(testResources.back() != '/')
20*89c4ff92SAndroid Build Coastguard Worker         {
21*89c4ff92SAndroid Build Coastguard Worker             return testResources + "/" + filename;
22*89c4ff92SAndroid Build Coastguard Worker         }
23*89c4ff92SAndroid Build Coastguard Worker         else
24*89c4ff92SAndroid Build Coastguard Worker         {
25*89c4ff92SAndroid Build Coastguard Worker             return testResources + filename;
26*89c4ff92SAndroid Build Coastguard Worker         }
27*89c4ff92SAndroid Build Coastguard Worker     }
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Test Network Execution SSD_MOBILE")
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker     std::string testResources = TEST_RESOURCE_DIR;
33*89c4ff92SAndroid Build Coastguard Worker     REQUIRE(testResources != "");
34*89c4ff92SAndroid Build Coastguard Worker     // Create the network options
35*89c4ff92SAndroid Build Coastguard Worker     common::PipelineOptions options;
36*89c4ff92SAndroid Build Coastguard Worker     options.m_ModelFilePath = GetResourceFilePath("ssd_mobilenet_v1.tflite");
37*89c4ff92SAndroid Build Coastguard Worker     options.m_ModelName = "SSD_MOBILE";
38*89c4ff92SAndroid Build Coastguard Worker     options.m_backends = {"CpuAcc", "CpuRef"};
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     common::InferenceResults<float> results;
43*89c4ff92SAndroid Build Coastguard Worker     cv::Mat processed;
44*89c4ff92SAndroid Build Coastguard Worker     cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
45*89c4ff92SAndroid Build Coastguard Worker     cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     objectDetectionPipeline->PreProcessing(inputFrame, processed);
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     CHECK(processed.type() == CV_8UC3);
50*89c4ff92SAndroid Build Coastguard Worker     CHECK(processed.cols == 300);
51*89c4ff92SAndroid Build Coastguard Worker     CHECK(processed.rows == 300);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     objectDetectionPipeline->Inference(processed, results);
54*89c4ff92SAndroid Build Coastguard Worker     objectDetectionPipeline->PostProcessing(results,
__anon384230110102(od::DetectedObjects detects) 55*89c4ff92SAndroid Build Coastguard Worker                                             [](od::DetectedObjects detects) -> void {
56*89c4ff92SAndroid Build Coastguard Worker                                                 CHECK(detects.size() == 2);
57*89c4ff92SAndroid Build Coastguard Worker                                                 CHECK(detects[0].GetLabel() == "0");
58*89c4ff92SAndroid Build Coastguard Worker                                             });
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker }
61