xref: /aosp_15_r20/external/armnn/src/profiling/test/ProfilingTests.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ProfilingMocks.hpp"
9 
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 #include <client/src/IProfilingConnection.hpp>
13 #include <client/src/ProfilingService.hpp>
14 
15 #include <armnn/profiling/ArmNNProfiling.hpp>
16 
17 #include <common/include/CommandHandlerFunctor.hpp>
18 #include <common/include/Logging.hpp>
19 
20 #include <doctest/doctest.h>
21 
22 #include <chrono>
23 #include <thread>
24 
25 namespace arm
26 {
27 
28 namespace pipe
29 {
30 
31 class TestProfilingConnectionBase : public IProfilingConnection
32 {
33 public:
34     TestProfilingConnectionBase() = default;
35     ~TestProfilingConnectionBase() = default;
36 
IsOpen() const37     bool IsOpen() const override { return true; }
38 
Close()39     void Close() override {}
40 
WritePacket(const unsigned char * buffer,uint32_t length)41     bool WritePacket(const unsigned char* buffer, uint32_t length) override
42     {
43         arm::pipe::IgnoreUnused(buffer, length);
44 
45         return false;
46     }
47 
ReadPacket(uint32_t timeout)48     arm::pipe::Packet ReadPacket(uint32_t timeout) override
49     {
50         // First time we're called return a connection ack packet. After that always timeout.
51         if (m_FirstCall)
52         {
53             m_FirstCall = false;
54             // Return connection acknowledged packet
55             return arm::pipe::Packet(65536);
56         }
57         else
58         {
59             std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
60             throw arm::pipe::TimeoutException("Simulate a timeout error\n");
61         }
62     }
63 
64     bool m_FirstCall = true;
65 };
66 
67 class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase
68 {
69 public:
TestProfilingConnectionTimeoutError()70     TestProfilingConnectionTimeoutError()
71         : m_ReadRequests(0)
72     {}
73 
ReadPacket(uint32_t timeout)74     arm::pipe::Packet ReadPacket(uint32_t timeout) override
75     {
76         // Return connection acknowledged packet after three timeouts
77         if (m_ReadRequests % 3 == 0)
78         {
79             std::this_thread::sleep_for(std::chrono::milliseconds(timeout));
80             ++m_ReadRequests;
81             throw arm::pipe::TimeoutException("Simulate a timeout error\n");
82         }
83 
84         return arm::pipe::Packet(65536);
85     }
86 
ReadCalledCount()87     int ReadCalledCount()
88     {
89         return m_ReadRequests.load();
90     }
91 
92 private:
93     std::atomic<int> m_ReadRequests;
94 };
95 
96 class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase
97 {
98 public:
TestProfilingConnectionArmnnError()99     TestProfilingConnectionArmnnError()
100         : m_ReadRequests(0)
101     {}
102 
ReadPacket(uint32_t timeout)103     arm::pipe::Packet ReadPacket(uint32_t timeout) override
104     {
105         arm::pipe::IgnoreUnused(timeout);
106         ++m_ReadRequests;
107         throw arm::pipe::ProfilingException("Simulate a non-timeout error");
108     }
109 
ReadCalledCount()110     int ReadCalledCount()
111     {
112         return m_ReadRequests.load();
113     }
114 
115 private:
116     std::atomic<int> m_ReadRequests;
117 };
118 
119 class TestProfilingConnectionBadAckPacket : public TestProfilingConnectionBase
120 {
121 public:
ReadPacket(uint32_t timeout)122     arm::pipe::Packet ReadPacket(uint32_t timeout) override
123     {
124         arm::pipe::IgnoreUnused(timeout);
125         // Connection Acknowledged Packet header (word 0, word 1 is always zero):
126         // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
127         // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
128         // 8:15  [8]  reserved: Reserved, value 0b00000000
129         // 0:7   [8]  reserved: Reserved, value 0b00000000
130         uint32_t packetFamily = 0;
131         uint32_t packetId     = 37;    // Wrong packet id!!!
132         uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
133 
134         return arm::pipe::Packet(header);
135     }
136 };
137 
138 class TestFunctorA : public arm::pipe::CommandHandlerFunctor
139 {
140 public:
141     using CommandHandlerFunctor::CommandHandlerFunctor;
142 
GetCount()143     int GetCount() { return m_Count; }
144 
operator ()(const arm::pipe::Packet & packet)145     void operator()(const arm::pipe::Packet& packet) override
146     {
147         arm::pipe::IgnoreUnused(packet);
148         m_Count++;
149     }
150 
151 private:
152     int m_Count = 0;
153 };
154 
155 class TestFunctorB : public TestFunctorA
156 {
157     using TestFunctorA::TestFunctorA;
158 };
159 
160 class TestFunctorC : public TestFunctorA
161 {
162     using TestFunctorA::TestFunctorA;
163 };
164 
165 class SwapProfilingConnectionFactoryHelper : public ProfilingService
166 {
167 public:
168     using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>;
169 
SwapProfilingConnectionFactoryHelper(uint16_t maxGlobalCounterId,IInitialiseProfilingService & initialiser,ProfilingService & profilingService)170     SwapProfilingConnectionFactoryHelper(uint16_t maxGlobalCounterId,
171                                          IInitialiseProfilingService& initialiser,
172                                          ProfilingService& profilingService)
173         : ProfilingService(maxGlobalCounterId,
174                            initialiser,
175                            arm::pipe::ARMNN_SOFTWARE_INFO,
176                            arm::pipe::ARMNN_SOFTWARE_VERSION,
177                            arm::pipe::ARMNN_HARDWARE_VERSION)
178         , m_ProfilingService(profilingService)
179         , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory())
180         , m_BackupProfilingConnectionFactory(nullptr)
181 
182     {
183         CHECK(m_MockProfilingConnectionFactory);
184         SwapProfilingConnectionFactory(m_ProfilingService,
185                                        m_MockProfilingConnectionFactory.get(),
186                                        m_BackupProfilingConnectionFactory);
187         CHECK(m_BackupProfilingConnectionFactory);
188     }
~SwapProfilingConnectionFactoryHelper()189     ~SwapProfilingConnectionFactoryHelper()
190     {
191         CHECK(m_BackupProfilingConnectionFactory);
192         IProfilingConnectionFactory* temp = nullptr;
193         SwapProfilingConnectionFactory(m_ProfilingService,
194                                        m_BackupProfilingConnectionFactory,
195                                        temp);
196     }
197 
GetMockProfilingConnection()198     MockProfilingConnection* GetMockProfilingConnection()
199     {
200         IProfilingConnection* profilingConnection = GetProfilingConnection(m_ProfilingService);
201         return armnn::PolymorphicDowncast<MockProfilingConnection*>(profilingConnection);
202     }
203 
ForceTransitionToState(ProfilingState newState)204     void ForceTransitionToState(ProfilingState newState)
205     {
206         TransitionToState(m_ProfilingService, newState);
207     }
208 
WaitForPacketsSent(MockProfilingConnection * mockProfilingConnection,MockProfilingConnection::PacketType packetType,uint32_t length=0,uint32_t timeout=1000)209     long WaitForPacketsSent(MockProfilingConnection* mockProfilingConnection,
210                             MockProfilingConnection::PacketType packetType,
211                             uint32_t length = 0,
212                             uint32_t timeout  = 1000)
213     {
214         long packetCount = mockProfilingConnection->CheckForPacket({ packetType, length });
215         // The first packet we receive may not be the one we are looking for, so keep looping until till we find it,
216         // or until WaitForPacketsSent times out
217         while(packetCount == 0 && timeout != 0)
218         {
219             std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
220             // Wait for a notification from the send thread
221             ProfilingService::WaitForPacketSent(m_ProfilingService, timeout);
222 
223             std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
224 
225             // We need to make sure the timeout does not reset each time we call WaitForPacketsSent
226             uint32_t elapsedTime = static_cast<uint32_t>(
227                     std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count());
228 
229             packetCount = mockProfilingConnection->CheckForPacket({packetType, length});
230 
231             if (elapsedTime > timeout)
232             {
233                 break;
234             }
235 
236             timeout -= elapsedTime;
237         }
238         return packetCount;
239     }
240 
241 private:
242     ProfilingService& m_ProfilingService;
243     MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory;
244     IProfilingConnectionFactory* m_BackupProfilingConnectionFactory;
245 };
246 
247 } // namespace pipe
248 
249 } // namespace arm
250