1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefTensorHandle.hpp"
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker namespace armnn
8*89c4ff92SAndroid Build Coastguard Worker {
9*89c4ff92SAndroid Build Coastguard Worker
TosaRefTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<TosaRefMemoryManager> & memoryManager)10*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
11*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<TosaRefMemoryManager>& memoryManager)
12*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
13*89c4ff92SAndroid Build Coastguard Worker , m_MemoryManager(memoryManager)
14*89c4ff92SAndroid Build Coastguard Worker , m_Pool(nullptr)
15*89c4ff92SAndroid Build Coastguard Worker , m_UnmanagedMemory(nullptr)
16*89c4ff92SAndroid Build Coastguard Worker , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17*89c4ff92SAndroid Build Coastguard Worker , m_Imported(false)
18*89c4ff92SAndroid Build Coastguard Worker , m_IsImportEnabled(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker
TosaRefTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
22*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags)
23*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
24*89c4ff92SAndroid Build Coastguard Worker , m_Pool(nullptr)
25*89c4ff92SAndroid Build Coastguard Worker , m_UnmanagedMemory(nullptr)
26*89c4ff92SAndroid Build Coastguard Worker , m_ImportFlags(importFlags)
27*89c4ff92SAndroid Build Coastguard Worker , m_Imported(false)
28*89c4ff92SAndroid Build Coastguard Worker , m_IsImportEnabled(true)
29*89c4ff92SAndroid Build Coastguard Worker {}
30*89c4ff92SAndroid Build Coastguard Worker
~TosaRefTensorHandle()31*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::~TosaRefTensorHandle()
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker if (!m_Pool)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker // unmanaged
36*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported)
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker ::operator delete(m_UnmanagedMemory);
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
Manage()43*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::Manage()
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(!m_Pool, "TosaRefTensorHandle::Manage() called twice");
48*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "TosaRefTensorHandle::Manage() called after Allocate()");
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker
Allocate()54*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::Allocate()
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker // If import is enabled, do not allocate the tensor
57*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker if (!m_UnmanagedMemory)
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker if (!m_Pool)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker // unmanaged
65*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker else
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker m_MemoryManager->Allocate(m_Pool);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker else
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("TosaRefTensorHandle::Allocate Trying to allocate a TosaRefTensorHandle"
75*89c4ff92SAndroid Build Coastguard Worker "that already has allocated memory.");
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker
Map(bool) const80*89c4ff92SAndroid Build Coastguard Worker const void* TosaRefTensorHandle::Map(bool /*unused*/) const
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker return GetPointer();
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker
GetPointer() const85*89c4ff92SAndroid Build Coastguard Worker void* TosaRefTensorHandle::GetPointer() const
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker if (m_UnmanagedMemory)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker return m_UnmanagedMemory;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker else if (m_Pool)
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker return m_MemoryManager->GetPointer(m_Pool);
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker else
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker throw NullPointerException("TosaRefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker
CopyOutTo(void * dest) const101*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::CopyOutTo(void* dest) const
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker const void *src = GetPointer();
104*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(src);
105*89c4ff92SAndroid Build Coastguard Worker memcpy(dest, src, m_TensorInfo.GetNumBytes());
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker
CopyInFrom(const void * src)108*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::CopyInFrom(const void* src)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker void *dest = GetPointer();
111*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(dest);
112*89c4ff92SAndroid Build Coastguard Worker memcpy(dest, src, m_TensorInfo.GetNumBytes());
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker
Import(void * memory,MemorySource source)115*89c4ff92SAndroid Build Coastguard Worker bool TosaRefTensorHandle::Import(void* memory, MemorySource source)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled && source == MemorySource::Malloc)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker // Check memory alignment
122*89c4ff92SAndroid Build Coastguard Worker if(!CanBeImported(memory, source))
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker if (m_Imported)
125*89c4ff92SAndroid Build Coastguard Worker {
126*89c4ff92SAndroid Build Coastguard Worker m_Imported = false;
127*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = nullptr;
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker return false;
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory not yet allocated.
133*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && !m_UnmanagedMemory)
134*89c4ff92SAndroid Build Coastguard Worker {
135*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = memory;
136*89c4ff92SAndroid Build Coastguard Worker m_Imported = true;
137*89c4ff92SAndroid Build Coastguard Worker return true;
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker
140*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory initially allocated with Allocate().
141*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && m_UnmanagedMemory)
142*89c4ff92SAndroid Build Coastguard Worker {
143*89c4ff92SAndroid Build Coastguard Worker return false;
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker // m_UnmanagedMemory previously imported.
147*89c4ff92SAndroid Build Coastguard Worker if (m_Imported)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker m_UnmanagedMemory = memory;
150*89c4ff92SAndroid Build Coastguard Worker return true;
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker
155*89c4ff92SAndroid Build Coastguard Worker return false;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker
CanBeImported(void * memory,MemorySource source)158*89c4ff92SAndroid Build Coastguard Worker bool TosaRefTensorHandle::CanBeImported(void* memory, MemorySource source)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled && source == MemorySource::Malloc)
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165*89c4ff92SAndroid Build Coastguard Worker if (reinterpret_cast<uintptr_t>(memory) % alignment)
166*89c4ff92SAndroid Build Coastguard Worker {
167*89c4ff92SAndroid Build Coastguard Worker return false;
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker return true;
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker return false;
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker }
176