1 // 2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "ProfilingMocks.hpp" 9 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 #include <client/src/IProfilingConnection.hpp> 13 #include <client/src/ProfilingService.hpp> 14 15 #include <armnn/profiling/ArmNNProfiling.hpp> 16 17 #include <common/include/CommandHandlerFunctor.hpp> 18 #include <common/include/Logging.hpp> 19 20 #include <doctest/doctest.h> 21 22 #include <chrono> 23 #include <thread> 24 25 namespace arm 26 { 27 28 namespace pipe 29 { 30 31 class TestProfilingConnectionBase : public IProfilingConnection 32 { 33 public: 34 TestProfilingConnectionBase() = default; 35 ~TestProfilingConnectionBase() = default; 36 IsOpen() const37 bool IsOpen() const override { return true; } 38 Close()39 void Close() override {} 40 WritePacket(const unsigned char * buffer,uint32_t length)41 bool WritePacket(const unsigned char* buffer, uint32_t length) override 42 { 43 arm::pipe::IgnoreUnused(buffer, length); 44 45 return false; 46 } 47 ReadPacket(uint32_t timeout)48 arm::pipe::Packet ReadPacket(uint32_t timeout) override 49 { 50 // First time we're called return a connection ack packet. After that always timeout. 51 if (m_FirstCall) 52 { 53 m_FirstCall = false; 54 // Return connection acknowledged packet 55 return arm::pipe::Packet(65536); 56 } 57 else 58 { 59 std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); 60 throw arm::pipe::TimeoutException("Simulate a timeout error\n"); 61 } 62 } 63 64 bool m_FirstCall = true; 65 }; 66 67 class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase 68 { 69 public: TestProfilingConnectionTimeoutError()70 TestProfilingConnectionTimeoutError() 71 : m_ReadRequests(0) 72 {} 73 ReadPacket(uint32_t timeout)74 arm::pipe::Packet ReadPacket(uint32_t timeout) override 75 { 76 // Return connection acknowledged packet after three timeouts 77 if (m_ReadRequests % 3 == 0) 78 { 79 std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); 80 ++m_ReadRequests; 81 throw arm::pipe::TimeoutException("Simulate a timeout error\n"); 82 } 83 84 return arm::pipe::Packet(65536); 85 } 86 ReadCalledCount()87 int ReadCalledCount() 88 { 89 return m_ReadRequests.load(); 90 } 91 92 private: 93 std::atomic<int> m_ReadRequests; 94 }; 95 96 class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase 97 { 98 public: TestProfilingConnectionArmnnError()99 TestProfilingConnectionArmnnError() 100 : m_ReadRequests(0) 101 {} 102 ReadPacket(uint32_t timeout)103 arm::pipe::Packet ReadPacket(uint32_t timeout) override 104 { 105 arm::pipe::IgnoreUnused(timeout); 106 ++m_ReadRequests; 107 throw arm::pipe::ProfilingException("Simulate a non-timeout error"); 108 } 109 ReadCalledCount()110 int ReadCalledCount() 111 { 112 return m_ReadRequests.load(); 113 } 114 115 private: 116 std::atomic<int> m_ReadRequests; 117 }; 118 119 class TestProfilingConnectionBadAckPacket : public TestProfilingConnectionBase 120 { 121 public: ReadPacket(uint32_t timeout)122 arm::pipe::Packet ReadPacket(uint32_t timeout) override 123 { 124 arm::pipe::IgnoreUnused(timeout); 125 // Connection Acknowledged Packet header (word 0, word 1 is always zero): 126 // 26:31 [6] packet_family: Control Packet Family, value 0b000000 127 // 16:25 [10] packet_id: Packet identifier, value 0b0000000001 128 // 8:15 [8] reserved: Reserved, value 0b00000000 129 // 0:7 [8] reserved: Reserved, value 0b00000000 130 uint32_t packetFamily = 0; 131 uint32_t packetId = 37; // Wrong packet id!!! 132 uint32_t header = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16); 133 134 return arm::pipe::Packet(header); 135 } 136 }; 137 138 class TestFunctorA : public arm::pipe::CommandHandlerFunctor 139 { 140 public: 141 using CommandHandlerFunctor::CommandHandlerFunctor; 142 GetCount()143 int GetCount() { return m_Count; } 144 operator ()(const arm::pipe::Packet & packet)145 void operator()(const arm::pipe::Packet& packet) override 146 { 147 arm::pipe::IgnoreUnused(packet); 148 m_Count++; 149 } 150 151 private: 152 int m_Count = 0; 153 }; 154 155 class TestFunctorB : public TestFunctorA 156 { 157 using TestFunctorA::TestFunctorA; 158 }; 159 160 class TestFunctorC : public TestFunctorA 161 { 162 using TestFunctorA::TestFunctorA; 163 }; 164 165 class SwapProfilingConnectionFactoryHelper : public ProfilingService 166 { 167 public: 168 using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>; 169 SwapProfilingConnectionFactoryHelper(uint16_t maxGlobalCounterId,IInitialiseProfilingService & initialiser,ProfilingService & profilingService)170 SwapProfilingConnectionFactoryHelper(uint16_t maxGlobalCounterId, 171 IInitialiseProfilingService& initialiser, 172 ProfilingService& profilingService) 173 : ProfilingService(maxGlobalCounterId, 174 initialiser, 175 arm::pipe::ARMNN_SOFTWARE_INFO, 176 arm::pipe::ARMNN_SOFTWARE_VERSION, 177 arm::pipe::ARMNN_HARDWARE_VERSION) 178 , m_ProfilingService(profilingService) 179 , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) 180 , m_BackupProfilingConnectionFactory(nullptr) 181 182 { 183 CHECK(m_MockProfilingConnectionFactory); 184 SwapProfilingConnectionFactory(m_ProfilingService, 185 m_MockProfilingConnectionFactory.get(), 186 m_BackupProfilingConnectionFactory); 187 CHECK(m_BackupProfilingConnectionFactory); 188 } ~SwapProfilingConnectionFactoryHelper()189 ~SwapProfilingConnectionFactoryHelper() 190 { 191 CHECK(m_BackupProfilingConnectionFactory); 192 IProfilingConnectionFactory* temp = nullptr; 193 SwapProfilingConnectionFactory(m_ProfilingService, 194 m_BackupProfilingConnectionFactory, 195 temp); 196 } 197 GetMockProfilingConnection()198 MockProfilingConnection* GetMockProfilingConnection() 199 { 200 IProfilingConnection* profilingConnection = GetProfilingConnection(m_ProfilingService); 201 return armnn::PolymorphicDowncast<MockProfilingConnection*>(profilingConnection); 202 } 203 ForceTransitionToState(ProfilingState newState)204 void ForceTransitionToState(ProfilingState newState) 205 { 206 TransitionToState(m_ProfilingService, newState); 207 } 208 WaitForPacketsSent(MockProfilingConnection * mockProfilingConnection,MockProfilingConnection::PacketType packetType,uint32_t length=0,uint32_t timeout=1000)209 long WaitForPacketsSent(MockProfilingConnection* mockProfilingConnection, 210 MockProfilingConnection::PacketType packetType, 211 uint32_t length = 0, 212 uint32_t timeout = 1000) 213 { 214 long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length }); 215 // The first packet we receive may not be the one we are looking for, so keep looping until till we find it, 216 // or until WaitForPacketsSent times out 217 while(packetCount == 0 && timeout != 0) 218 { 219 std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now(); 220 // Wait for a notification from the send thread 221 ProfilingService::WaitForPacketSent(m_ProfilingService, timeout); 222 223 std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now(); 224 225 // We need to make sure the timeout does not reset each time we call WaitForPacketsSent 226 uint32_t elapsedTime = static_cast<uint32_t>( 227 std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()); 228 229 packetCount = mockProfilingConnection->CheckForPacket({packetType, length}); 230 231 if (elapsedTime > timeout) 232 { 233 break; 234 } 235 236 timeout -= elapsedTime; 237 } 238 return packetCount; 239 } 240 241 private: 242 ProfilingService& m_ProfilingService; 243 MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; 244 IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; 245 }; 246 247 } // namespace pipe 248 249 } // namespace arm 250