xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/MockTensorHandle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnTestUtils/MockTensorHandle.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker namespace armnn
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker 
MockTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<MockMemoryManager> & memoryManager)11*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12*89c4ff92SAndroid Build Coastguard Worker     : m_TensorInfo(tensorInfo)
13*89c4ff92SAndroid Build Coastguard Worker     , m_MemoryManager(memoryManager)
14*89c4ff92SAndroid Build Coastguard Worker     , m_Pool(nullptr)
15*89c4ff92SAndroid Build Coastguard Worker     , m_UnmanagedMemory(nullptr)
16*89c4ff92SAndroid Build Coastguard Worker     , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17*89c4ff92SAndroid Build Coastguard Worker     , m_Imported(false)
18*89c4ff92SAndroid Build Coastguard Worker     , m_IsImportEnabled(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker 
MockTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22*89c4ff92SAndroid Build Coastguard Worker     : m_TensorInfo(tensorInfo)
23*89c4ff92SAndroid Build Coastguard Worker     , m_Pool(nullptr)
24*89c4ff92SAndroid Build Coastguard Worker     , m_UnmanagedMemory(nullptr)
25*89c4ff92SAndroid Build Coastguard Worker     , m_ImportFlags(importFlags)
26*89c4ff92SAndroid Build Coastguard Worker     , m_Imported(false)
27*89c4ff92SAndroid Build Coastguard Worker     , m_IsImportEnabled(true)
28*89c4ff92SAndroid Build Coastguard Worker {}
29*89c4ff92SAndroid Build Coastguard Worker 
~MockTensorHandle()30*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::~MockTensorHandle()
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker     if (!m_Pool)
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         // unmanaged
35*89c4ff92SAndroid Build Coastguard Worker         if (!m_Imported)
36*89c4ff92SAndroid Build Coastguard Worker         {
37*89c4ff92SAndroid Build Coastguard Worker             ::operator delete(m_UnmanagedMemory);
38*89c4ff92SAndroid Build Coastguard Worker         }
39*89c4ff92SAndroid Build Coastguard Worker     }
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
Manage()42*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::Manage()
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker     if (!m_IsImportEnabled)
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50*89c4ff92SAndroid Build Coastguard Worker     }
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
Allocate()53*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::Allocate()
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker     // If import is enabled, do not allocate the tensor
56*89c4ff92SAndroid Build Coastguard Worker     if (!m_IsImportEnabled)
57*89c4ff92SAndroid Build Coastguard Worker     {
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker         if (!m_UnmanagedMemory)
60*89c4ff92SAndroid Build Coastguard Worker         {
61*89c4ff92SAndroid Build Coastguard Worker             if (!m_Pool)
62*89c4ff92SAndroid Build Coastguard Worker             {
63*89c4ff92SAndroid Build Coastguard Worker                 // unmanaged
64*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65*89c4ff92SAndroid Build Coastguard Worker             }
66*89c4ff92SAndroid Build Coastguard Worker             else
67*89c4ff92SAndroid Build Coastguard Worker             {
68*89c4ff92SAndroid Build Coastguard Worker                 m_MemoryManager->Allocate(m_Pool);
69*89c4ff92SAndroid Build Coastguard Worker             }
70*89c4ff92SAndroid Build Coastguard Worker         }
71*89c4ff92SAndroid Build Coastguard Worker         else
72*89c4ff92SAndroid Build Coastguard Worker         {
73*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74*89c4ff92SAndroid Build Coastguard Worker                                            "that already has allocated memory.");
75*89c4ff92SAndroid Build Coastguard Worker         }
76*89c4ff92SAndroid Build Coastguard Worker     }
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker 
Map(bool) const79*89c4ff92SAndroid Build Coastguard Worker const void* MockTensorHandle::Map(bool /*unused*/) const
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     return GetPointer();
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
GetPointer() const84*89c4ff92SAndroid Build Coastguard Worker void* MockTensorHandle::GetPointer() const
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker     if (m_UnmanagedMemory)
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         return m_UnmanagedMemory;
89*89c4ff92SAndroid Build Coastguard Worker     }
90*89c4ff92SAndroid Build Coastguard Worker     else if (m_Pool)
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker         return m_MemoryManager->GetPointer(m_Pool);
93*89c4ff92SAndroid Build Coastguard Worker     }
94*89c4ff92SAndroid Build Coastguard Worker     else
95*89c4ff92SAndroid Build Coastguard Worker     {
96*89c4ff92SAndroid Build Coastguard Worker         throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97*89c4ff92SAndroid Build Coastguard Worker     }
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
CopyOutTo(void * dest) const100*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::CopyOutTo(void* dest) const
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker     const void* src = GetPointer();
103*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(src);
104*89c4ff92SAndroid Build Coastguard Worker     memcpy(dest, src, m_TensorInfo.GetNumBytes());
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
CopyInFrom(const void * src)107*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::CopyInFrom(const void* src)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker     void* dest = GetPointer();
110*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(dest);
111*89c4ff92SAndroid Build Coastguard Worker     memcpy(dest, src, m_TensorInfo.GetNumBytes());
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker 
Import(void * memory,MemorySource source)114*89c4ff92SAndroid Build Coastguard Worker bool MockTensorHandle::Import(void* memory, MemorySource source)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117*89c4ff92SAndroid Build Coastguard Worker     {
118*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled && source == MemorySource::Malloc)
119*89c4ff92SAndroid Build Coastguard Worker         {
120*89c4ff92SAndroid Build Coastguard Worker             // Check memory alignment
121*89c4ff92SAndroid Build Coastguard Worker             if (!CanBeImported(memory, source))
122*89c4ff92SAndroid Build Coastguard Worker             {
123*89c4ff92SAndroid Build Coastguard Worker                 if (m_Imported)
124*89c4ff92SAndroid Build Coastguard Worker                 {
125*89c4ff92SAndroid Build Coastguard Worker                     m_Imported        = false;
126*89c4ff92SAndroid Build Coastguard Worker                     m_UnmanagedMemory = nullptr;
127*89c4ff92SAndroid Build Coastguard Worker                 }
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker                 return false;
130*89c4ff92SAndroid Build Coastguard Worker             }
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory not yet allocated.
133*89c4ff92SAndroid Build Coastguard Worker             if (!m_Imported && !m_UnmanagedMemory)
134*89c4ff92SAndroid Build Coastguard Worker             {
135*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = memory;
136*89c4ff92SAndroid Build Coastguard Worker                 m_Imported        = true;
137*89c4ff92SAndroid Build Coastguard Worker                 return true;
138*89c4ff92SAndroid Build Coastguard Worker             }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory initially allocated with Allocate().
141*89c4ff92SAndroid Build Coastguard Worker             if (!m_Imported && m_UnmanagedMemory)
142*89c4ff92SAndroid Build Coastguard Worker             {
143*89c4ff92SAndroid Build Coastguard Worker                 return false;
144*89c4ff92SAndroid Build Coastguard Worker             }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory previously imported.
147*89c4ff92SAndroid Build Coastguard Worker             if (m_Imported)
148*89c4ff92SAndroid Build Coastguard Worker             {
149*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = memory;
150*89c4ff92SAndroid Build Coastguard Worker                 return true;
151*89c4ff92SAndroid Build Coastguard Worker             }
152*89c4ff92SAndroid Build Coastguard Worker         }
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     return false;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker 
CanBeImported(void * memory,MemorySource source)158*89c4ff92SAndroid Build Coastguard Worker bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled && source == MemorySource::Malloc)
163*89c4ff92SAndroid Build Coastguard Worker         {
164*89c4ff92SAndroid Build Coastguard Worker             uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165*89c4ff92SAndroid Build Coastguard Worker             if (reinterpret_cast<uintptr_t>(memory) % alignment)
166*89c4ff92SAndroid Build Coastguard Worker             {
167*89c4ff92SAndroid Build Coastguard Worker                 return false;
168*89c4ff92SAndroid Build Coastguard Worker             }
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker             return true;
171*89c4ff92SAndroid Build Coastguard Worker         }
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker     return false;
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker }    // namespace armnn
177