1*3e777be0SXin Li // 2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*3e777be0SXin Li // SPDX-License-Identifier: MIT 4*3e777be0SXin Li // 5*3e777be0SXin Li 6*3e777be0SXin Li #include "DriverTestHelpers.hpp" 7*3e777be0SXin Li #include <log/log.h> 8*3e777be0SXin Li 9*3e777be0SXin Li #include <armnn/src/armnn/OptimizedNetworkImpl.hpp> 10*3e777be0SXin Li 11*3e777be0SXin Li #include <fstream> 12*3e777be0SXin Li #include <memory> 13*3e777be0SXin Li #include <armnn/INetwork.hpp> 14*3e777be0SXin Li 15*3e777be0SXin Li #include <armnnUtils/Filesystem.hpp> 16*3e777be0SXin Li 17*3e777be0SXin Li using namespace android; 18*3e777be0SXin Li using namespace android::nn; 19*3e777be0SXin Li using namespace android::hardware; 20*3e777be0SXin Li using namespace armnn_driver; 21*3e777be0SXin Li 22*3e777be0SXin Li namespace armnn 23*3e777be0SXin Li { 24*3e777be0SXin Li 25*3e777be0SXin Li class Graph 26*3e777be0SXin Li { 27*3e777be0SXin Li public: 28*3e777be0SXin Li Graph(Graph&& graph) = default; 29*3e777be0SXin Li }; 30*3e777be0SXin Li 31*3e777be0SXin Li class MockOptimizedNetworkImpl final : public ::armnn::OptimizedNetworkImpl 32*3e777be0SXin Li { 33*3e777be0SXin Li public: MockOptimizedNetworkImpl(const std::string & mockSerializedContent,std::unique_ptr<armnn::Graph>)34*3e777be0SXin Li MockOptimizedNetworkImpl(const std::string& mockSerializedContent, std::unique_ptr<armnn::Graph>) 35*3e777be0SXin Li : ::armnn::OptimizedNetworkImpl(nullptr) 36*3e777be0SXin Li , m_MockSerializedContent(mockSerializedContent) 37*3e777be0SXin Li {} ~MockOptimizedNetworkImpl()38*3e777be0SXin Li ~MockOptimizedNetworkImpl() {} 39*3e777be0SXin Li PrintGraph()40*3e777be0SXin Li ::armnn::Status PrintGraph() override { return ::armnn::Status::Failure; } SerializeToDot(std::ostream & stream) const41*3e777be0SXin Li ::armnn::Status SerializeToDot(std::ostream& stream) const override 42*3e777be0SXin Li { 43*3e777be0SXin Li stream << m_MockSerializedContent; 44*3e777be0SXin Li 45*3e777be0SXin Li return stream.good() ? ::armnn::Status::Success : ::armnn::Status::Failure; 46*3e777be0SXin Li } 47*3e777be0SXin Li GetGuid() const48*3e777be0SXin Li ::arm::pipe::ProfilingGuid GetGuid() const final { return ::arm::pipe::ProfilingGuid(0); } 49*3e777be0SXin Li UpdateMockSerializedContent(const std::string & mockSerializedContent)50*3e777be0SXin Li void UpdateMockSerializedContent(const std::string& mockSerializedContent) 51*3e777be0SXin Li { 52*3e777be0SXin Li this->m_MockSerializedContent = mockSerializedContent; 53*3e777be0SXin Li } 54*3e777be0SXin Li 55*3e777be0SXin Li private: 56*3e777be0SXin Li std::string m_MockSerializedContent; 57*3e777be0SXin Li }; 58*3e777be0SXin Li 59*3e777be0SXin Li 60*3e777be0SXin Li } // armnn namespace 61*3e777be0SXin Li 62*3e777be0SXin Li 63*3e777be0SXin Li // The following are helpers for writing unit tests for the driver. 64*3e777be0SXin Li namespace 65*3e777be0SXin Li { 66*3e777be0SXin Li 67*3e777be0SXin Li struct ExportNetworkGraphFixture 68*3e777be0SXin Li { 69*3e777be0SXin Li public: 70*3e777be0SXin Li // Setup: set the output dump directory and an empty dummy model (as only its memory address is used). 71*3e777be0SXin Li // Defaulting the output dump directory to "/data" because it should exist and be writable in all deployments. ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture72*3e777be0SXin Li ExportNetworkGraphFixture() 73*3e777be0SXin Li : ExportNetworkGraphFixture("/data") 74*3e777be0SXin Li {} 75*3e777be0SXin Li ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture76*3e777be0SXin Li ExportNetworkGraphFixture(const std::string& requestInputsAndOutputsDumpDir) 77*3e777be0SXin Li : m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir), m_FileName(), m_FileStream() 78*3e777be0SXin Li { 79*3e777be0SXin Li // Set the name of the output .dot file. 80*3e777be0SXin Li // NOTE: the export now uses a time stamp to name the file so we 81*3e777be0SXin Li // can't predict ahead of time what the file name will be. 82*3e777be0SXin Li std::string timestamp = "dummy"; 83*3e777be0SXin Li m_FileName = m_RequestInputsAndOutputsDumpDir / (timestamp + "_networkgraph.dot"); 84*3e777be0SXin Li } 85*3e777be0SXin Li 86*3e777be0SXin Li // Teardown: delete the dump file regardless of the outcome of the tests. ~ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture87*3e777be0SXin Li ~ExportNetworkGraphFixture() 88*3e777be0SXin Li { 89*3e777be0SXin Li // Close the file stream. 90*3e777be0SXin Li m_FileStream.close(); 91*3e777be0SXin Li 92*3e777be0SXin Li // Ignore any error (such as file not found). 93*3e777be0SXin Li (void) remove(m_FileName.c_str()); 94*3e777be0SXin Li } 95*3e777be0SXin Li FileExists__anon5503157d0111::ExportNetworkGraphFixture96*3e777be0SXin Li bool FileExists() 97*3e777be0SXin Li { 98*3e777be0SXin Li // Close any file opened in a previous session. 99*3e777be0SXin Li if (m_FileStream.is_open()) 100*3e777be0SXin Li { 101*3e777be0SXin Li m_FileStream.close(); 102*3e777be0SXin Li } 103*3e777be0SXin Li 104*3e777be0SXin Li if (m_FileName.empty()) 105*3e777be0SXin Li { 106*3e777be0SXin Li return false; 107*3e777be0SXin Li } 108*3e777be0SXin Li 109*3e777be0SXin Li // Open the file. 110*3e777be0SXin Li m_FileStream.open(m_FileName, std::ifstream::in); 111*3e777be0SXin Li 112*3e777be0SXin Li // Check that the file is open. 113*3e777be0SXin Li if (!m_FileStream.is_open()) 114*3e777be0SXin Li { 115*3e777be0SXin Li return false; 116*3e777be0SXin Li } 117*3e777be0SXin Li 118*3e777be0SXin Li // Check that the stream is readable. 119*3e777be0SXin Li return m_FileStream.good(); 120*3e777be0SXin Li } 121*3e777be0SXin Li GetFileContent__anon5503157d0111::ExportNetworkGraphFixture122*3e777be0SXin Li std::string GetFileContent() 123*3e777be0SXin Li { 124*3e777be0SXin Li // Check that the stream is readable. 125*3e777be0SXin Li if (!m_FileStream.good()) 126*3e777be0SXin Li { 127*3e777be0SXin Li return ""; 128*3e777be0SXin Li } 129*3e777be0SXin Li 130*3e777be0SXin Li // Get all the contents of the file. 131*3e777be0SXin Li return std::string((std::istreambuf_iterator<char>(m_FileStream)), 132*3e777be0SXin Li (std::istreambuf_iterator<char>())); 133*3e777be0SXin Li } 134*3e777be0SXin Li 135*3e777be0SXin Li fs::path m_RequestInputsAndOutputsDumpDir; 136*3e777be0SXin Li fs::path m_FileName; 137*3e777be0SXin Li 138*3e777be0SXin Li private: 139*3e777be0SXin Li std::ifstream m_FileStream; 140*3e777be0SXin Li }; 141*3e777be0SXin Li 142*3e777be0SXin Li 143*3e777be0SXin Li } // namespace 144*3e777be0SXin Li 145*3e777be0SXin Li DOCTEST_TEST_SUITE("UtilsTests") 146*3e777be0SXin Li { 147*3e777be0SXin Li 148*3e777be0SXin Li DOCTEST_TEST_CASE("ExportToEmptyDirectory") 149*3e777be0SXin Li { 150*3e777be0SXin Li // Set the fixture for this test. 151*3e777be0SXin Li ExportNetworkGraphFixture fixture(""); 152*3e777be0SXin Li 153*3e777be0SXin Li // Set a mock content for the optimized network. 154*3e777be0SXin Li std::string mockSerializedContent = "This is a mock serialized content."; 155*3e777be0SXin Li 156*3e777be0SXin Li // Set a mock optimized network. 157*3e777be0SXin Li std::unique_ptr<armnn::Graph> graphPtr; 158*3e777be0SXin Li 159*3e777be0SXin Li std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 160*3e777be0SXin Li new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 161*3e777be0SXin Li ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 162*3e777be0SXin Li 163*3e777be0SXin Li // Export the mock optimized network. 164*3e777be0SXin Li fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 165*3e777be0SXin Li fixture.m_RequestInputsAndOutputsDumpDir); 166*3e777be0SXin Li 167*3e777be0SXin Li // Check that the output file does not exist. 168*3e777be0SXin Li DOCTEST_CHECK(!fixture.FileExists()); 169*3e777be0SXin Li } 170*3e777be0SXin Li 171*3e777be0SXin Li DOCTEST_TEST_CASE("ExportNetwork") 172*3e777be0SXin Li { 173*3e777be0SXin Li // Set the fixture for this test. 174*3e777be0SXin Li ExportNetworkGraphFixture fixture; 175*3e777be0SXin Li 176*3e777be0SXin Li // Set a mock content for the optimized network. 177*3e777be0SXin Li std::string mockSerializedContent = "This is a mock serialized content."; 178*3e777be0SXin Li 179*3e777be0SXin Li // Set a mock optimized network. 180*3e777be0SXin Li std::unique_ptr<armnn::Graph> graphPtr; 181*3e777be0SXin Li 182*3e777be0SXin Li std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 183*3e777be0SXin Li new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 184*3e777be0SXin Li ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 185*3e777be0SXin Li 186*3e777be0SXin Li 187*3e777be0SXin Li // Export the mock optimized network. 188*3e777be0SXin Li fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 189*3e777be0SXin Li fixture.m_RequestInputsAndOutputsDumpDir); 190*3e777be0SXin Li 191*3e777be0SXin Li // Check that the output file exists and that it has the correct name. 192*3e777be0SXin Li DOCTEST_CHECK(fixture.FileExists()); 193*3e777be0SXin Li 194*3e777be0SXin Li // Check that the content of the output file matches the mock content. 195*3e777be0SXin Li DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 196*3e777be0SXin Li } 197*3e777be0SXin Li 198*3e777be0SXin Li DOCTEST_TEST_CASE("ExportNetworkOverwriteFile") 199*3e777be0SXin Li { 200*3e777be0SXin Li // Set the fixture for this test. 201*3e777be0SXin Li ExportNetworkGraphFixture fixture; 202*3e777be0SXin Li 203*3e777be0SXin Li // Set a mock content for the optimized network. 204*3e777be0SXin Li std::string mockSerializedContent = "This is a mock serialized content."; 205*3e777be0SXin Li 206*3e777be0SXin Li // Set a mock optimized network. 207*3e777be0SXin Li std::unique_ptr<armnn::Graph> graphPtr; 208*3e777be0SXin Li 209*3e777be0SXin Li std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 210*3e777be0SXin Li new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 211*3e777be0SXin Li ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 212*3e777be0SXin Li 213*3e777be0SXin Li // Export the mock optimized network. 214*3e777be0SXin Li fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 215*3e777be0SXin Li fixture.m_RequestInputsAndOutputsDumpDir); 216*3e777be0SXin Li 217*3e777be0SXin Li // Check that the output file exists and that it has the correct name. 218*3e777be0SXin Li DOCTEST_CHECK(fixture.FileExists()); 219*3e777be0SXin Li 220*3e777be0SXin Li // Check that the content of the output file matches the mock content. 221*3e777be0SXin Li DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 222*3e777be0SXin Li 223*3e777be0SXin Li // Update the mock serialized content of the network. 224*3e777be0SXin Li mockSerializedContent = "This is ANOTHER mock serialized content!"; 225*3e777be0SXin Li std::unique_ptr<armnn::Graph> graphPtr2; 226*3e777be0SXin Li std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl2( 227*3e777be0SXin Li new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr2))); 228*3e777be0SXin Li static_cast<armnn::MockOptimizedNetworkImpl*>(mockImpl2.get())->UpdateMockSerializedContent(mockSerializedContent); 229*3e777be0SXin Li ::armnn::IOptimizedNetwork mockOptimizedNetwork2(std::move(mockImpl2)); 230*3e777be0SXin Li 231*3e777be0SXin Li // Export the mock optimized network. 232*3e777be0SXin Li fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork2, 233*3e777be0SXin Li fixture.m_RequestInputsAndOutputsDumpDir); 234*3e777be0SXin Li 235*3e777be0SXin Li // Check that the output file still exists and that it has the correct name. 236*3e777be0SXin Li DOCTEST_CHECK(fixture.FileExists()); 237*3e777be0SXin Li 238*3e777be0SXin Li // Check that the content of the output file matches the mock content. 239*3e777be0SXin Li DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 240*3e777be0SXin Li } 241*3e777be0SXin Li 242*3e777be0SXin Li DOCTEST_TEST_CASE("ExportMultipleNetworks") 243*3e777be0SXin Li { 244*3e777be0SXin Li // Set the fixtures for this test. 245*3e777be0SXin Li ExportNetworkGraphFixture fixture1; 246*3e777be0SXin Li ExportNetworkGraphFixture fixture2; 247*3e777be0SXin Li ExportNetworkGraphFixture fixture3; 248*3e777be0SXin Li 249*3e777be0SXin Li // Set a mock content for the optimized network. 250*3e777be0SXin Li std::string mockSerializedContent = "This is a mock serialized content."; 251*3e777be0SXin Li 252*3e777be0SXin Li // Set a mock optimized network. 253*3e777be0SXin Li std::unique_ptr<armnn::Graph> graphPtr; 254*3e777be0SXin Li 255*3e777be0SXin Li std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 256*3e777be0SXin Li new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 257*3e777be0SXin Li ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 258*3e777be0SXin Li 259*3e777be0SXin Li // Export the mock optimized network. 260*3e777be0SXin Li fixture1.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 261*3e777be0SXin Li fixture1.m_RequestInputsAndOutputsDumpDir); 262*3e777be0SXin Li 263*3e777be0SXin Li // Check that the output file exists and that it has the correct name. 264*3e777be0SXin Li DOCTEST_CHECK(fixture1.FileExists()); 265*3e777be0SXin Li 266*3e777be0SXin Li // Check that the content of the output file matches the mock content. 267*3e777be0SXin Li DOCTEST_CHECK(fixture1.GetFileContent() == mockSerializedContent); 268*3e777be0SXin Li 269*3e777be0SXin Li // Export the mock optimized network. 270*3e777be0SXin Li fixture2.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 271*3e777be0SXin Li fixture2.m_RequestInputsAndOutputsDumpDir); 272*3e777be0SXin Li 273*3e777be0SXin Li // Check that the output file exists and that it has the correct name. 274*3e777be0SXin Li DOCTEST_CHECK(fixture2.FileExists()); 275*3e777be0SXin Li 276*3e777be0SXin Li // Check that the content of the output file matches the mock content. 277*3e777be0SXin Li DOCTEST_CHECK(fixture2.GetFileContent() == mockSerializedContent); 278*3e777be0SXin Li 279*3e777be0SXin Li // Export the mock optimized network. 280*3e777be0SXin Li fixture3.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 281*3e777be0SXin Li fixture3.m_RequestInputsAndOutputsDumpDir); 282*3e777be0SXin Li // Check that the output file exists and that it has the correct name. 283*3e777be0SXin Li DOCTEST_CHECK(fixture3.FileExists()); 284*3e777be0SXin Li 285*3e777be0SXin Li // Check that the content of the output file matches the mock content. 286*3e777be0SXin Li DOCTEST_CHECK(fixture3.GetFileContent() == mockSerializedContent); 287*3e777be0SXin Li } 288*3e777be0SXin Li 289*3e777be0SXin Li } 290