xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/TosaRefTensorHandle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #include "TosaRefTensorHandle.hpp"
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker namespace armnn
8*89c4ff92SAndroid Build Coastguard Worker {
9*89c4ff92SAndroid Build Coastguard Worker 
TosaRefTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<TosaRefMemoryManager> & memoryManager)10*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
11*89c4ff92SAndroid Build Coastguard Worker                                          std::shared_ptr<TosaRefMemoryManager>& memoryManager)
12*89c4ff92SAndroid Build Coastguard Worker      : m_TensorInfo(tensorInfo)
13*89c4ff92SAndroid Build Coastguard Worker      , m_MemoryManager(memoryManager)
14*89c4ff92SAndroid Build Coastguard Worker      , m_Pool(nullptr)
15*89c4ff92SAndroid Build Coastguard Worker      , m_UnmanagedMemory(nullptr)
16*89c4ff92SAndroid Build Coastguard Worker      , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17*89c4ff92SAndroid Build Coastguard Worker      , m_Imported(false)
18*89c4ff92SAndroid Build Coastguard Worker      , m_IsImportEnabled(false)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker 
TosaRefTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::TosaRefTensorHandle(const TensorInfo& tensorInfo,
22*89c4ff92SAndroid Build Coastguard Worker                                          MemorySourceFlags importFlags)
23*89c4ff92SAndroid Build Coastguard Worker     : m_TensorInfo(tensorInfo)
24*89c4ff92SAndroid Build Coastguard Worker     , m_Pool(nullptr)
25*89c4ff92SAndroid Build Coastguard Worker     , m_UnmanagedMemory(nullptr)
26*89c4ff92SAndroid Build Coastguard Worker     , m_ImportFlags(importFlags)
27*89c4ff92SAndroid Build Coastguard Worker     , m_Imported(false)
28*89c4ff92SAndroid Build Coastguard Worker     , m_IsImportEnabled(true)
29*89c4ff92SAndroid Build Coastguard Worker {}
30*89c4ff92SAndroid Build Coastguard Worker 
~TosaRefTensorHandle()31*89c4ff92SAndroid Build Coastguard Worker TosaRefTensorHandle::~TosaRefTensorHandle()
32*89c4ff92SAndroid Build Coastguard Worker {
33*89c4ff92SAndroid Build Coastguard Worker     if (!m_Pool)
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         // unmanaged
36*89c4ff92SAndroid Build Coastguard Worker         if (!m_Imported)
37*89c4ff92SAndroid Build Coastguard Worker         {
38*89c4ff92SAndroid Build Coastguard Worker             ::operator delete(m_UnmanagedMemory);
39*89c4ff92SAndroid Build Coastguard Worker         }
40*89c4ff92SAndroid Build Coastguard Worker     }
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker 
Manage()43*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::Manage()
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker     if (!m_IsImportEnabled)
46*89c4ff92SAndroid Build Coastguard Worker     {
47*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(!m_Pool, "TosaRefTensorHandle::Manage() called twice");
48*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "TosaRefTensorHandle::Manage() called after Allocate()");
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
51*89c4ff92SAndroid Build Coastguard Worker     }
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker 
Allocate()54*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::Allocate()
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker     // If import is enabled, do not allocate the tensor
57*89c4ff92SAndroid Build Coastguard Worker     if (!m_IsImportEnabled)
58*89c4ff92SAndroid Build Coastguard Worker     {
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker         if (!m_UnmanagedMemory)
61*89c4ff92SAndroid Build Coastguard Worker         {
62*89c4ff92SAndroid Build Coastguard Worker             if (!m_Pool)
63*89c4ff92SAndroid Build Coastguard Worker             {
64*89c4ff92SAndroid Build Coastguard Worker                 // unmanaged
65*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
66*89c4ff92SAndroid Build Coastguard Worker             }
67*89c4ff92SAndroid Build Coastguard Worker             else
68*89c4ff92SAndroid Build Coastguard Worker             {
69*89c4ff92SAndroid Build Coastguard Worker                 m_MemoryManager->Allocate(m_Pool);
70*89c4ff92SAndroid Build Coastguard Worker             }
71*89c4ff92SAndroid Build Coastguard Worker         }
72*89c4ff92SAndroid Build Coastguard Worker         else
73*89c4ff92SAndroid Build Coastguard Worker         {
74*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("TosaRefTensorHandle::Allocate Trying to allocate a TosaRefTensorHandle"
75*89c4ff92SAndroid Build Coastguard Worker                                            "that already has allocated memory.");
76*89c4ff92SAndroid Build Coastguard Worker         }
77*89c4ff92SAndroid Build Coastguard Worker     }
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker 
Map(bool) const80*89c4ff92SAndroid Build Coastguard Worker const void* TosaRefTensorHandle::Map(bool /*unused*/) const
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker     return GetPointer();
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker 
GetPointer() const85*89c4ff92SAndroid Build Coastguard Worker void* TosaRefTensorHandle::GetPointer() const
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     if (m_UnmanagedMemory)
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         return m_UnmanagedMemory;
90*89c4ff92SAndroid Build Coastguard Worker     }
91*89c4ff92SAndroid Build Coastguard Worker     else if (m_Pool)
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         return m_MemoryManager->GetPointer(m_Pool);
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker     else
96*89c4ff92SAndroid Build Coastguard Worker     {
97*89c4ff92SAndroid Build Coastguard Worker         throw NullPointerException("TosaRefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker 
CopyOutTo(void * dest) const101*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::CopyOutTo(void* dest) const
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker     const void *src = GetPointer();
104*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(src);
105*89c4ff92SAndroid Build Coastguard Worker     memcpy(dest, src, m_TensorInfo.GetNumBytes());
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
CopyInFrom(const void * src)108*89c4ff92SAndroid Build Coastguard Worker void TosaRefTensorHandle::CopyInFrom(const void* src)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     void *dest = GetPointer();
111*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(dest);
112*89c4ff92SAndroid Build Coastguard Worker     memcpy(dest, src, m_TensorInfo.GetNumBytes());
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
Import(void * memory,MemorySource source)115*89c4ff92SAndroid Build Coastguard Worker bool TosaRefTensorHandle::Import(void* memory, MemorySource source)
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled && source == MemorySource::Malloc)
120*89c4ff92SAndroid Build Coastguard Worker         {
121*89c4ff92SAndroid Build Coastguard Worker             // Check memory alignment
122*89c4ff92SAndroid Build Coastguard Worker             if(!CanBeImported(memory, source))
123*89c4ff92SAndroid Build Coastguard Worker             {
124*89c4ff92SAndroid Build Coastguard Worker                 if (m_Imported)
125*89c4ff92SAndroid Build Coastguard Worker                 {
126*89c4ff92SAndroid Build Coastguard Worker                     m_Imported = false;
127*89c4ff92SAndroid Build Coastguard Worker                     m_UnmanagedMemory = nullptr;
128*89c4ff92SAndroid Build Coastguard Worker                 }
129*89c4ff92SAndroid Build Coastguard Worker                 return false;
130*89c4ff92SAndroid Build Coastguard Worker             }
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory not yet allocated.
133*89c4ff92SAndroid Build Coastguard Worker             if (!m_Imported && !m_UnmanagedMemory)
134*89c4ff92SAndroid Build Coastguard Worker             {
135*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = memory;
136*89c4ff92SAndroid Build Coastguard Worker                 m_Imported = true;
137*89c4ff92SAndroid Build Coastguard Worker                 return true;
138*89c4ff92SAndroid Build Coastguard Worker             }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory initially allocated with Allocate().
141*89c4ff92SAndroid Build Coastguard Worker             if (!m_Imported && m_UnmanagedMemory)
142*89c4ff92SAndroid Build Coastguard Worker             {
143*89c4ff92SAndroid Build Coastguard Worker                 return false;
144*89c4ff92SAndroid Build Coastguard Worker             }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker             // m_UnmanagedMemory previously imported.
147*89c4ff92SAndroid Build Coastguard Worker             if (m_Imported)
148*89c4ff92SAndroid Build Coastguard Worker             {
149*89c4ff92SAndroid Build Coastguard Worker                 m_UnmanagedMemory = memory;
150*89c4ff92SAndroid Build Coastguard Worker                 return true;
151*89c4ff92SAndroid Build Coastguard Worker             }
152*89c4ff92SAndroid Build Coastguard Worker         }
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     return false;
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker 
CanBeImported(void * memory,MemorySource source)158*89c4ff92SAndroid Build Coastguard Worker bool TosaRefTensorHandle::CanBeImported(void* memory, MemorySource source)
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled && source == MemorySource::Malloc)
163*89c4ff92SAndroid Build Coastguard Worker         {
164*89c4ff92SAndroid Build Coastguard Worker             uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165*89c4ff92SAndroid Build Coastguard Worker             if (reinterpret_cast<uintptr_t>(memory) % alignment)
166*89c4ff92SAndroid Build Coastguard Worker             {
167*89c4ff92SAndroid Build Coastguard Worker                 return false;
168*89c4ff92SAndroid Build Coastguard Worker             }
169*89c4ff92SAndroid Build Coastguard Worker             return true;
170*89c4ff92SAndroid Build Coastguard Worker         }
171*89c4ff92SAndroid Build Coastguard Worker     }
172*89c4ff92SAndroid Build Coastguard Worker     return false;
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker }
176