// xref: /aosp_15_r20/external/android-nn-driver/test/UtilsTests.cpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*3e777be0SXin Li 
#include "DriverTestHelpers.hpp"
#include <log/log.h>

#include <armnn/INetwork.hpp>
#include <armnn/src/armnn/OptimizedNetworkImpl.hpp>
#include <armnnUtils/Filesystem.hpp>

#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <system_error>
#include <utility>

using namespace android;
using namespace android::nn;
using namespace android::hardware;
using namespace armnn_driver;
21*3e777be0SXin Li 
22*3e777be0SXin Li namespace armnn
23*3e777be0SXin Li {
24*3e777be0SXin Li 
25*3e777be0SXin Li class Graph
26*3e777be0SXin Li {
27*3e777be0SXin Li public:
28*3e777be0SXin Li     Graph(Graph&& graph) = default;
29*3e777be0SXin Li };
30*3e777be0SXin Li 
31*3e777be0SXin Li class MockOptimizedNetworkImpl final : public ::armnn::OptimizedNetworkImpl
32*3e777be0SXin Li {
33*3e777be0SXin Li public:
MockOptimizedNetworkImpl(const std::string & mockSerializedContent,std::unique_ptr<armnn::Graph>)34*3e777be0SXin Li     MockOptimizedNetworkImpl(const std::string& mockSerializedContent, std::unique_ptr<armnn::Graph>)
35*3e777be0SXin Li         : ::armnn::OptimizedNetworkImpl(nullptr)
36*3e777be0SXin Li         , m_MockSerializedContent(mockSerializedContent)
37*3e777be0SXin Li     {}
~MockOptimizedNetworkImpl()38*3e777be0SXin Li     ~MockOptimizedNetworkImpl() {}
39*3e777be0SXin Li 
PrintGraph()40*3e777be0SXin Li     ::armnn::Status PrintGraph() override { return ::armnn::Status::Failure; }
SerializeToDot(std::ostream & stream) const41*3e777be0SXin Li     ::armnn::Status SerializeToDot(std::ostream& stream) const override
42*3e777be0SXin Li     {
43*3e777be0SXin Li         stream << m_MockSerializedContent;
44*3e777be0SXin Li 
45*3e777be0SXin Li         return stream.good() ? ::armnn::Status::Success : ::armnn::Status::Failure;
46*3e777be0SXin Li     }
47*3e777be0SXin Li 
GetGuid() const48*3e777be0SXin Li     ::arm::pipe::ProfilingGuid GetGuid() const final { return ::arm::pipe::ProfilingGuid(0); }
49*3e777be0SXin Li 
UpdateMockSerializedContent(const std::string & mockSerializedContent)50*3e777be0SXin Li     void UpdateMockSerializedContent(const std::string& mockSerializedContent)
51*3e777be0SXin Li     {
52*3e777be0SXin Li         this->m_MockSerializedContent = mockSerializedContent;
53*3e777be0SXin Li     }
54*3e777be0SXin Li 
55*3e777be0SXin Li private:
56*3e777be0SXin Li     std::string m_MockSerializedContent;
57*3e777be0SXin Li };
58*3e777be0SXin Li 
59*3e777be0SXin Li 
60*3e777be0SXin Li } // armnn namespace
61*3e777be0SXin Li 
62*3e777be0SXin Li 
63*3e777be0SXin Li // The following are helpers for writing unit tests for the driver.
64*3e777be0SXin Li namespace
65*3e777be0SXin Li {
66*3e777be0SXin Li 
67*3e777be0SXin Li struct ExportNetworkGraphFixture
68*3e777be0SXin Li {
69*3e777be0SXin Li public:
70*3e777be0SXin Li     // Setup: set the output dump directory and an empty dummy model (as only its memory address is used).
71*3e777be0SXin Li     // Defaulting the output dump directory to "/data" because it should exist and be writable in all deployments.
ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture72*3e777be0SXin Li     ExportNetworkGraphFixture()
73*3e777be0SXin Li         : ExportNetworkGraphFixture("/data")
74*3e777be0SXin Li     {}
75*3e777be0SXin Li 
ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture76*3e777be0SXin Li     ExportNetworkGraphFixture(const std::string& requestInputsAndOutputsDumpDir)
77*3e777be0SXin Li         : m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir), m_FileName(), m_FileStream()
78*3e777be0SXin Li     {
79*3e777be0SXin Li         // Set the name of the output .dot file.
80*3e777be0SXin Li         // NOTE: the export now uses a time stamp to name the file so we
81*3e777be0SXin Li         //       can't predict ahead of time what the file name will be.
82*3e777be0SXin Li         std::string timestamp = "dummy";
83*3e777be0SXin Li         m_FileName = m_RequestInputsAndOutputsDumpDir / (timestamp + "_networkgraph.dot");
84*3e777be0SXin Li     }
85*3e777be0SXin Li 
86*3e777be0SXin Li     // Teardown: delete the dump file regardless of the outcome of the tests.
~ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture87*3e777be0SXin Li     ~ExportNetworkGraphFixture()
88*3e777be0SXin Li     {
89*3e777be0SXin Li         // Close the file stream.
90*3e777be0SXin Li         m_FileStream.close();
91*3e777be0SXin Li 
92*3e777be0SXin Li         // Ignore any error (such as file not found).
93*3e777be0SXin Li         (void) remove(m_FileName.c_str());
94*3e777be0SXin Li     }
95*3e777be0SXin Li 
FileExists__anon5503157d0111::ExportNetworkGraphFixture96*3e777be0SXin Li     bool FileExists()
97*3e777be0SXin Li     {
98*3e777be0SXin Li         // Close any file opened in a previous session.
99*3e777be0SXin Li         if (m_FileStream.is_open())
100*3e777be0SXin Li         {
101*3e777be0SXin Li             m_FileStream.close();
102*3e777be0SXin Li         }
103*3e777be0SXin Li 
104*3e777be0SXin Li         if (m_FileName.empty())
105*3e777be0SXin Li         {
106*3e777be0SXin Li             return false;
107*3e777be0SXin Li         }
108*3e777be0SXin Li 
109*3e777be0SXin Li         // Open the file.
110*3e777be0SXin Li         m_FileStream.open(m_FileName, std::ifstream::in);
111*3e777be0SXin Li 
112*3e777be0SXin Li         // Check that the file is open.
113*3e777be0SXin Li         if (!m_FileStream.is_open())
114*3e777be0SXin Li         {
115*3e777be0SXin Li             return false;
116*3e777be0SXin Li         }
117*3e777be0SXin Li 
118*3e777be0SXin Li         // Check that the stream is readable.
119*3e777be0SXin Li         return m_FileStream.good();
120*3e777be0SXin Li     }
121*3e777be0SXin Li 
GetFileContent__anon5503157d0111::ExportNetworkGraphFixture122*3e777be0SXin Li     std::string GetFileContent()
123*3e777be0SXin Li     {
124*3e777be0SXin Li         // Check that the stream is readable.
125*3e777be0SXin Li         if (!m_FileStream.good())
126*3e777be0SXin Li         {
127*3e777be0SXin Li             return "";
128*3e777be0SXin Li         }
129*3e777be0SXin Li 
130*3e777be0SXin Li         // Get all the contents of the file.
131*3e777be0SXin Li         return std::string((std::istreambuf_iterator<char>(m_FileStream)),
132*3e777be0SXin Li                            (std::istreambuf_iterator<char>()));
133*3e777be0SXin Li     }
134*3e777be0SXin Li 
135*3e777be0SXin Li     fs::path m_RequestInputsAndOutputsDumpDir;
136*3e777be0SXin Li     fs::path m_FileName;
137*3e777be0SXin Li 
138*3e777be0SXin Li private:
139*3e777be0SXin Li     std::ifstream m_FileStream;
140*3e777be0SXin Li };
141*3e777be0SXin Li 
142*3e777be0SXin Li 
143*3e777be0SXin Li } // namespace
144*3e777be0SXin Li 
145*3e777be0SXin Li DOCTEST_TEST_SUITE("UtilsTests")
146*3e777be0SXin Li {
147*3e777be0SXin Li 
148*3e777be0SXin Li DOCTEST_TEST_CASE("ExportToEmptyDirectory")
149*3e777be0SXin Li {
150*3e777be0SXin Li     // Set the fixture for this test.
151*3e777be0SXin Li     ExportNetworkGraphFixture fixture("");
152*3e777be0SXin Li 
153*3e777be0SXin Li     // Set a mock content for the optimized network.
154*3e777be0SXin Li     std::string mockSerializedContent = "This is a mock serialized content.";
155*3e777be0SXin Li 
156*3e777be0SXin Li     // Set a mock optimized network.
157*3e777be0SXin Li     std::unique_ptr<armnn::Graph> graphPtr;
158*3e777be0SXin Li 
159*3e777be0SXin Li     std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl(
160*3e777be0SXin Li         new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr)));
161*3e777be0SXin Li     ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl));
162*3e777be0SXin Li 
163*3e777be0SXin Li     // Export the mock optimized network.
164*3e777be0SXin Li     fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
165*3e777be0SXin Li                                                                    fixture.m_RequestInputsAndOutputsDumpDir);
166*3e777be0SXin Li 
167*3e777be0SXin Li     // Check that the output file does not exist.
168*3e777be0SXin Li     DOCTEST_CHECK(!fixture.FileExists());
169*3e777be0SXin Li }
170*3e777be0SXin Li 
171*3e777be0SXin Li DOCTEST_TEST_CASE("ExportNetwork")
172*3e777be0SXin Li {
173*3e777be0SXin Li     // Set the fixture for this test.
174*3e777be0SXin Li     ExportNetworkGraphFixture fixture;
175*3e777be0SXin Li 
176*3e777be0SXin Li     // Set a mock content for the optimized network.
177*3e777be0SXin Li     std::string mockSerializedContent = "This is a mock serialized content.";
178*3e777be0SXin Li 
179*3e777be0SXin Li     // Set a mock optimized network.
180*3e777be0SXin Li     std::unique_ptr<armnn::Graph> graphPtr;
181*3e777be0SXin Li 
182*3e777be0SXin Li     std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl(
183*3e777be0SXin Li         new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr)));
184*3e777be0SXin Li     ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl));
185*3e777be0SXin Li 
186*3e777be0SXin Li 
187*3e777be0SXin Li     // Export the mock optimized network.
188*3e777be0SXin Li     fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
189*3e777be0SXin Li                                                                    fixture.m_RequestInputsAndOutputsDumpDir);
190*3e777be0SXin Li 
191*3e777be0SXin Li     // Check that the output file exists and that it has the correct name.
192*3e777be0SXin Li     DOCTEST_CHECK(fixture.FileExists());
193*3e777be0SXin Li 
194*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
195*3e777be0SXin Li     DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent);
196*3e777be0SXin Li }
197*3e777be0SXin Li 
198*3e777be0SXin Li DOCTEST_TEST_CASE("ExportNetworkOverwriteFile")
199*3e777be0SXin Li {
200*3e777be0SXin Li     // Set the fixture for this test.
201*3e777be0SXin Li     ExportNetworkGraphFixture fixture;
202*3e777be0SXin Li 
203*3e777be0SXin Li     // Set a mock content for the optimized network.
204*3e777be0SXin Li     std::string mockSerializedContent = "This is a mock serialized content.";
205*3e777be0SXin Li 
206*3e777be0SXin Li     // Set a mock optimized network.
207*3e777be0SXin Li     std::unique_ptr<armnn::Graph> graphPtr;
208*3e777be0SXin Li 
209*3e777be0SXin Li     std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl(
210*3e777be0SXin Li         new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr)));
211*3e777be0SXin Li     ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl));
212*3e777be0SXin Li 
213*3e777be0SXin Li     // Export the mock optimized network.
214*3e777be0SXin Li     fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
215*3e777be0SXin Li                                                                    fixture.m_RequestInputsAndOutputsDumpDir);
216*3e777be0SXin Li 
217*3e777be0SXin Li     // Check that the output file exists and that it has the correct name.
218*3e777be0SXin Li     DOCTEST_CHECK(fixture.FileExists());
219*3e777be0SXin Li 
220*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
221*3e777be0SXin Li     DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent);
222*3e777be0SXin Li 
223*3e777be0SXin Li     // Update the mock serialized content of the network.
224*3e777be0SXin Li     mockSerializedContent = "This is ANOTHER mock serialized content!";
225*3e777be0SXin Li     std::unique_ptr<armnn::Graph> graphPtr2;
226*3e777be0SXin Li     std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl2(
227*3e777be0SXin Li         new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr2)));
228*3e777be0SXin Li     static_cast<armnn::MockOptimizedNetworkImpl*>(mockImpl2.get())->UpdateMockSerializedContent(mockSerializedContent);
229*3e777be0SXin Li     ::armnn::IOptimizedNetwork mockOptimizedNetwork2(std::move(mockImpl2));
230*3e777be0SXin Li 
231*3e777be0SXin Li     // Export the mock optimized network.
232*3e777be0SXin Li     fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork2,
233*3e777be0SXin Li                                                                    fixture.m_RequestInputsAndOutputsDumpDir);
234*3e777be0SXin Li 
235*3e777be0SXin Li     // Check that the output file still exists and that it has the correct name.
236*3e777be0SXin Li     DOCTEST_CHECK(fixture.FileExists());
237*3e777be0SXin Li 
238*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
239*3e777be0SXin Li     DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent);
240*3e777be0SXin Li }
241*3e777be0SXin Li 
242*3e777be0SXin Li DOCTEST_TEST_CASE("ExportMultipleNetworks")
243*3e777be0SXin Li {
244*3e777be0SXin Li     // Set the fixtures for this test.
245*3e777be0SXin Li     ExportNetworkGraphFixture fixture1;
246*3e777be0SXin Li     ExportNetworkGraphFixture fixture2;
247*3e777be0SXin Li     ExportNetworkGraphFixture fixture3;
248*3e777be0SXin Li 
249*3e777be0SXin Li     // Set a mock content for the optimized network.
250*3e777be0SXin Li     std::string mockSerializedContent = "This is a mock serialized content.";
251*3e777be0SXin Li 
252*3e777be0SXin Li     // Set a mock optimized network.
253*3e777be0SXin Li     std::unique_ptr<armnn::Graph> graphPtr;
254*3e777be0SXin Li 
255*3e777be0SXin Li     std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl(
256*3e777be0SXin Li         new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr)));
257*3e777be0SXin Li     ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl));
258*3e777be0SXin Li 
259*3e777be0SXin Li     // Export the mock optimized network.
260*3e777be0SXin Li     fixture1.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
261*3e777be0SXin Li                                                                     fixture1.m_RequestInputsAndOutputsDumpDir);
262*3e777be0SXin Li 
263*3e777be0SXin Li     // Check that the output file exists and that it has the correct name.
264*3e777be0SXin Li     DOCTEST_CHECK(fixture1.FileExists());
265*3e777be0SXin Li 
266*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
267*3e777be0SXin Li     DOCTEST_CHECK(fixture1.GetFileContent() == mockSerializedContent);
268*3e777be0SXin Li 
269*3e777be0SXin Li     // Export the mock optimized network.
270*3e777be0SXin Li     fixture2.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
271*3e777be0SXin Li                                                                     fixture2.m_RequestInputsAndOutputsDumpDir);
272*3e777be0SXin Li 
273*3e777be0SXin Li     // Check that the output file exists and that it has the correct name.
274*3e777be0SXin Li     DOCTEST_CHECK(fixture2.FileExists());
275*3e777be0SXin Li 
276*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
277*3e777be0SXin Li     DOCTEST_CHECK(fixture2.GetFileContent() == mockSerializedContent);
278*3e777be0SXin Li 
279*3e777be0SXin Li     // Export the mock optimized network.
280*3e777be0SXin Li     fixture3.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork,
281*3e777be0SXin Li                                                                     fixture3.m_RequestInputsAndOutputsDumpDir);
282*3e777be0SXin Li     // Check that the output file exists and that it has the correct name.
283*3e777be0SXin Li     DOCTEST_CHECK(fixture3.FileExists());
284*3e777be0SXin Li 
285*3e777be0SXin Li     // Check that the content of the output file matches the mock content.
286*3e777be0SXin Li     DOCTEST_CHECK(fixture3.GetFileContent() == mockSerializedContent);
287*3e777be0SXin Li }
288*3e777be0SXin Li 
289*3e777be0SXin Li }
290