1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnTestUtils/MockTensorHandle.hpp"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker namespace armnn
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker
MockTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<MockMemoryManager> & memoryManager)11*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
13*89c4ff92SAndroid Build Coastguard Worker , m_MemoryManager(memoryManager)
14*89c4ff92SAndroid Build Coastguard Worker , m_Pool(nullptr)
15*89c4ff92SAndroid Build Coastguard Worker , m_UnmanagedMemory(nullptr)
16*89c4ff92SAndroid Build Coastguard Worker , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17*89c4ff92SAndroid Build Coastguard Worker , m_Imported(false)
18*89c4ff92SAndroid Build Coastguard Worker , m_IsImportEnabled(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker
MockTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
23*89c4ff92SAndroid Build Coastguard Worker , m_Pool(nullptr)
24*89c4ff92SAndroid Build Coastguard Worker , m_UnmanagedMemory(nullptr)
25*89c4ff92SAndroid Build Coastguard Worker , m_ImportFlags(importFlags)
26*89c4ff92SAndroid Build Coastguard Worker , m_Imported(false)
27*89c4ff92SAndroid Build Coastguard Worker , m_IsImportEnabled(true)
28*89c4ff92SAndroid Build Coastguard Worker {}
29*89c4ff92SAndroid Build Coastguard Worker
~MockTensorHandle()30*89c4ff92SAndroid Build Coastguard Worker MockTensorHandle::~MockTensorHandle()
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker if (!m_Pool)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker // unmanaged
35*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported)
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker ::operator delete(m_UnmanagedMemory);
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker
Manage()42*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::Manage()
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled)
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker
Allocate()53*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::Allocate()
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker // If import is enabled, do not allocate the tensor
56*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker if (!m_UnmanagedMemory)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker if (!m_Pool)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker // unmanaged
64*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker else
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker m_MemoryManager->Allocate(m_Pool);
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker else
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74*89c4ff92SAndroid Build Coastguard Worker "that already has allocated memory.");
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker
Map(bool) const79*89c4ff92SAndroid Build Coastguard Worker const void* MockTensorHandle::Map(bool /*unused*/) const
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker return GetPointer();
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker
GetPointer() const84*89c4ff92SAndroid Build Coastguard Worker void* MockTensorHandle::GetPointer() const
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker if (m_UnmanagedMemory)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker return m_UnmanagedMemory;
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker else if (m_Pool)
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker return m_MemoryManager->GetPointer(m_Pool);
93*89c4ff92SAndroid Build Coastguard Worker }
94*89c4ff92SAndroid Build Coastguard Worker else
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker
CopyOutTo(void * dest) const100*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::CopyOutTo(void* dest) const
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker const void* src = GetPointer();
103*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(src);
104*89c4ff92SAndroid Build Coastguard Worker memcpy(dest, src, m_TensorInfo.GetNumBytes());
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker
CopyInFrom(const void * src)107*89c4ff92SAndroid Build Coastguard Worker void MockTensorHandle::CopyInFrom(const void* src)
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker void* dest = GetPointer();
110*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(dest);
111*89c4ff92SAndroid Build Coastguard Worker memcpy(dest, src, m_TensorInfo.GetNumBytes());
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
Import(void * memory,MemorySource source)114*89c4ff92SAndroid Build Coastguard Worker bool MockTensorHandle::Import(void* memory, MemorySource source)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled && source == MemorySource::Malloc)
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker // Check memory alignment
121*89c4ff92SAndroid Build Coastguard Worker if (!CanBeImported(memory, source))
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker if (m_Imported)
124*89c4ff92SAndroid Build Coastguard Worker {
125*89c4ff92SAndroid Build Coastguard Worker m_Imported = false;
126*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = nullptr;
127*89c4ff92SAndroid Build Coastguard Worker }
128*89c4ff92SAndroid Build Coastguard Worker
129*89c4ff92SAndroid Build Coastguard Worker return false;
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory not yet allocated.
133*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && !m_UnmanagedMemory)
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = memory;
136*89c4ff92SAndroid Build Coastguard Worker m_Imported = true;
137*89c4ff92SAndroid Build Coastguard Worker return true;
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory initially allocated with Allocate().
141*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && m_UnmanagedMemory)
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker return false;
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory previously imported.
147*89c4ff92SAndroid Build Coastguard Worker if (m_Imported)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = memory;
150*89c4ff92SAndroid Build Coastguard Worker return true;
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker return false;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker
CanBeImported(void * memory,MemorySource source)158*89c4ff92SAndroid Build Coastguard Worker bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled && source == MemorySource::Malloc)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165*89c4ff92SAndroid Build Coastguard Worker if (reinterpret_cast<uintptr_t>(memory) % alignment)
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker return false;
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker
170*89c4ff92SAndroid Build Coastguard Worker return true;
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker return false;
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
177