xref: /aosp_15_r20/external/armnn/src/backends/neon/NeonTensorHandle.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <BFloat16.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorHandle.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/MemoryGroup.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/IMemoryGroup.h>
18*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/Tensor.h>
19*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/SubTensor.h>
20*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/TensorShape.h>
21*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Coordinates.h>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker namespace armnn
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker class NeonTensorHandle : public IAclTensorHandle
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker public:
NeonTensorHandle(const TensorInfo & tensorInfo)29*89c4ff92SAndroid Build Coastguard Worker     NeonTensorHandle(const TensorInfo& tensorInfo)
30*89c4ff92SAndroid Build Coastguard Worker                      : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
31*89c4ff92SAndroid Build Coastguard Worker                        m_Imported(false),
32*89c4ff92SAndroid Build Coastguard Worker                        m_IsImportEnabled(false),
33*89c4ff92SAndroid Build Coastguard Worker                        m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
36*89c4ff92SAndroid Build Coastguard Worker     }
37*89c4ff92SAndroid Build Coastguard Worker 
NeonTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Malloc))38*89c4ff92SAndroid Build Coastguard Worker     NeonTensorHandle(const TensorInfo& tensorInfo,
39*89c4ff92SAndroid Build Coastguard Worker                      DataLayout dataLayout,
40*89c4ff92SAndroid Build Coastguard Worker                      MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc))
41*89c4ff92SAndroid Build Coastguard Worker                      : m_ImportFlags(importFlags),
42*89c4ff92SAndroid Build Coastguard Worker                        m_Imported(false),
43*89c4ff92SAndroid Build Coastguard Worker                        m_IsImportEnabled(false),
44*89c4ff92SAndroid Build Coastguard Worker                        m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType()))
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     {
48*89c4ff92SAndroid Build Coastguard Worker         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
49*89c4ff92SAndroid Build Coastguard Worker     }
50*89c4ff92SAndroid Build Coastguard Worker 
GetTensor()51*89c4ff92SAndroid Build Coastguard Worker     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
GetTensor() const52*89c4ff92SAndroid Build Coastguard Worker     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
53*89c4ff92SAndroid Build Coastguard Worker 
Allocate()54*89c4ff92SAndroid Build Coastguard Worker     virtual void Allocate() override
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         // If we have enabled Importing, don't Allocate the tensor
57*89c4ff92SAndroid Build Coastguard Worker         if (!m_IsImportEnabled)
58*89c4ff92SAndroid Build Coastguard Worker         {
59*89c4ff92SAndroid Build Coastguard Worker             armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
60*89c4ff92SAndroid Build Coastguard Worker         }
61*89c4ff92SAndroid Build Coastguard Worker     };
62*89c4ff92SAndroid Build Coastguard Worker 
Manage()63*89c4ff92SAndroid Build Coastguard Worker     virtual void Manage() override
64*89c4ff92SAndroid Build Coastguard Worker     {
65*89c4ff92SAndroid Build Coastguard Worker         // If we have enabled Importing, don't manage the tensor
66*89c4ff92SAndroid Build Coastguard Worker         if (!m_IsImportEnabled)
67*89c4ff92SAndroid Build Coastguard Worker         {
68*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(m_MemoryGroup != nullptr);
69*89c4ff92SAndroid Build Coastguard Worker             m_MemoryGroup->manage(&m_Tensor);
70*89c4ff92SAndroid Build Coastguard Worker         }
71*89c4ff92SAndroid Build Coastguard Worker     }
72*89c4ff92SAndroid Build Coastguard Worker 
GetParent() const73*89c4ff92SAndroid Build Coastguard Worker     virtual ITensorHandle* GetParent() const override { return nullptr; }
74*89c4ff92SAndroid Build Coastguard Worker 
GetDataType() const75*89c4ff92SAndroid Build Coastguard Worker     virtual arm_compute::DataType GetDataType() const override
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         return m_Tensor.info()->data_type();
78*89c4ff92SAndroid Build Coastguard Worker     }
79*89c4ff92SAndroid Build Coastguard Worker 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)80*89c4ff92SAndroid Build Coastguard Worker     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
81*89c4ff92SAndroid Build Coastguard Worker     {
82*89c4ff92SAndroid Build Coastguard Worker         m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
83*89c4ff92SAndroid Build Coastguard Worker     }
84*89c4ff92SAndroid Build Coastguard Worker 
Map(bool) const85*89c4ff92SAndroid Build Coastguard Worker     virtual const void* Map(bool /* blocking = true */) const override
86*89c4ff92SAndroid Build Coastguard Worker     {
87*89c4ff92SAndroid Build Coastguard Worker         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker 
Unmap() const90*89c4ff92SAndroid Build Coastguard Worker     virtual void Unmap() const override {}
91*89c4ff92SAndroid Build Coastguard Worker 
GetStrides() const92*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetStrides() const override
93*89c4ff92SAndroid Build Coastguard Worker     {
94*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker 
GetShape() const97*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetShape() const override
98*89c4ff92SAndroid Build Coastguard Worker     {
99*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
100*89c4ff92SAndroid Build Coastguard Worker     }
101*89c4ff92SAndroid Build Coastguard Worker 
SetImportFlags(MemorySourceFlags importFlags)102*89c4ff92SAndroid Build Coastguard Worker     void SetImportFlags(MemorySourceFlags importFlags)
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker         m_ImportFlags = importFlags;
105*89c4ff92SAndroid Build Coastguard Worker     }
106*89c4ff92SAndroid Build Coastguard Worker 
GetImportFlags() const107*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags GetImportFlags() const override
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         return m_ImportFlags;
110*89c4ff92SAndroid Build Coastguard Worker     }
111*89c4ff92SAndroid Build Coastguard Worker 
SetImportEnabledFlag(bool importEnabledFlag)112*89c4ff92SAndroid Build Coastguard Worker     void SetImportEnabledFlag(bool importEnabledFlag)
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         m_IsImportEnabled = importEnabledFlag;
115*89c4ff92SAndroid Build Coastguard Worker     }
116*89c4ff92SAndroid Build Coastguard Worker 
CanBeImported(void * memory,MemorySource source)117*89c4ff92SAndroid Build Coastguard Worker     bool CanBeImported(void* memory, MemorySource source) override
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment)
120*89c4ff92SAndroid Build Coastguard Worker         {
121*89c4ff92SAndroid Build Coastguard Worker             return false;
122*89c4ff92SAndroid Build Coastguard Worker         }
123*89c4ff92SAndroid Build Coastguard Worker         return true;
124*89c4ff92SAndroid Build Coastguard Worker     }
125*89c4ff92SAndroid Build Coastguard Worker 
Import(void * memory,MemorySource source)126*89c4ff92SAndroid Build Coastguard Worker     virtual bool Import(void* memory, MemorySource source) override
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
129*89c4ff92SAndroid Build Coastguard Worker         {
130*89c4ff92SAndroid Build Coastguard Worker             if (source == MemorySource::Malloc && m_IsImportEnabled)
131*89c4ff92SAndroid Build Coastguard Worker             {
132*89c4ff92SAndroid Build Coastguard Worker                 if (!CanBeImported(memory, source))
133*89c4ff92SAndroid Build Coastguard Worker                 {
134*89c4ff92SAndroid Build Coastguard Worker                     throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory");
135*89c4ff92SAndroid Build Coastguard Worker                 }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker                 // m_Tensor not yet Allocated
138*89c4ff92SAndroid Build Coastguard Worker                 if (!m_Imported && !m_Tensor.buffer())
139*89c4ff92SAndroid Build Coastguard Worker                 {
140*89c4ff92SAndroid Build Coastguard Worker                     arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
141*89c4ff92SAndroid Build Coastguard Worker                     // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
142*89c4ff92SAndroid Build Coastguard Worker                     // with the Status error message
143*89c4ff92SAndroid Build Coastguard Worker                     m_Imported = bool(status);
144*89c4ff92SAndroid Build Coastguard Worker                     if (!m_Imported)
145*89c4ff92SAndroid Build Coastguard Worker                     {
146*89c4ff92SAndroid Build Coastguard Worker                         throw MemoryImportException(status.error_description());
147*89c4ff92SAndroid Build Coastguard Worker                     }
148*89c4ff92SAndroid Build Coastguard Worker                     return m_Imported;
149*89c4ff92SAndroid Build Coastguard Worker                 }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker                 // m_Tensor.buffer() initially allocated with Allocate().
152*89c4ff92SAndroid Build Coastguard Worker                 if (!m_Imported && m_Tensor.buffer())
153*89c4ff92SAndroid Build Coastguard Worker                 {
154*89c4ff92SAndroid Build Coastguard Worker                     throw MemoryImportException(
155*89c4ff92SAndroid Build Coastguard Worker                         "NeonTensorHandle::Import Attempting to import on an already allocated tensor");
156*89c4ff92SAndroid Build Coastguard Worker                 }
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker                 // m_Tensor.buffer() previously imported.
159*89c4ff92SAndroid Build Coastguard Worker                 if (m_Imported)
160*89c4ff92SAndroid Build Coastguard Worker                 {
161*89c4ff92SAndroid Build Coastguard Worker                     arm_compute::Status status = m_Tensor.allocator()->import_memory(memory);
162*89c4ff92SAndroid Build Coastguard Worker                     // Use the overloaded bool operator of Status to check if it worked, if not throw an exception
163*89c4ff92SAndroid Build Coastguard Worker                     // with the Status error message
164*89c4ff92SAndroid Build Coastguard Worker                     m_Imported = bool(status);
165*89c4ff92SAndroid Build Coastguard Worker                     if (!m_Imported)
166*89c4ff92SAndroid Build Coastguard Worker                     {
167*89c4ff92SAndroid Build Coastguard Worker                         throw MemoryImportException(status.error_description());
168*89c4ff92SAndroid Build Coastguard Worker                     }
169*89c4ff92SAndroid Build Coastguard Worker                     return m_Imported;
170*89c4ff92SAndroid Build Coastguard Worker                 }
171*89c4ff92SAndroid Build Coastguard Worker             }
172*89c4ff92SAndroid Build Coastguard Worker             else
173*89c4ff92SAndroid Build Coastguard Worker             {
174*89c4ff92SAndroid Build Coastguard Worker                 throw MemoryImportException("NeonTensorHandle::Import is disabled");
175*89c4ff92SAndroid Build Coastguard Worker             }
176*89c4ff92SAndroid Build Coastguard Worker         }
177*89c4ff92SAndroid Build Coastguard Worker         else
178*89c4ff92SAndroid Build Coastguard Worker         {
179*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("NeonTensorHandle::Incorrect import flag");
180*89c4ff92SAndroid Build Coastguard Worker         }
181*89c4ff92SAndroid Build Coastguard Worker         return false;
182*89c4ff92SAndroid Build Coastguard Worker     }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker private:
185*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyOutTo(void * memory) const186*89c4ff92SAndroid Build Coastguard Worker     void CopyOutTo(void* memory) const override
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         switch (this->GetDataType())
189*89c4ff92SAndroid Build Coastguard Worker         {
190*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
191*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
192*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<float*>(memory));
193*89c4ff92SAndroid Build Coastguard Worker                 break;
194*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
195*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
196*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
197*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<uint8_t*>(memory));
198*89c4ff92SAndroid Build Coastguard Worker                 break;
199*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
200*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
201*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
202*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int8_t*>(memory));
203*89c4ff92SAndroid Build Coastguard Worker                 break;
204*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::BFLOAT16:
205*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
206*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<armnn::BFloat16*>(memory));
207*89c4ff92SAndroid Build Coastguard Worker                 break;
208*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
209*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
210*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<armnn::Half*>(memory));
211*89c4ff92SAndroid Build Coastguard Worker                 break;
212*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
213*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
214*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
215*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int16_t*>(memory));
216*89c4ff92SAndroid Build Coastguard Worker                 break;
217*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
218*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
219*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int32_t*>(memory));
220*89c4ff92SAndroid Build Coastguard Worker                 break;
221*89c4ff92SAndroid Build Coastguard Worker             default:
222*89c4ff92SAndroid Build Coastguard Worker             {
223*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
224*89c4ff92SAndroid Build Coastguard Worker             }
225*89c4ff92SAndroid Build Coastguard Worker         }
226*89c4ff92SAndroid Build Coastguard Worker     }
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyInFrom(const void * memory)229*89c4ff92SAndroid Build Coastguard Worker     void CopyInFrom(const void* memory) override
230*89c4ff92SAndroid Build Coastguard Worker     {
231*89c4ff92SAndroid Build Coastguard Worker         switch (this->GetDataType())
232*89c4ff92SAndroid Build Coastguard Worker         {
233*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
234*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
235*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
236*89c4ff92SAndroid Build Coastguard Worker                 break;
237*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
238*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
239*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
240*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
241*89c4ff92SAndroid Build Coastguard Worker                 break;
242*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
243*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
244*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
245*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
246*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
247*89c4ff92SAndroid Build Coastguard Worker                 break;
248*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::BFLOAT16:
249*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory),
250*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
251*89c4ff92SAndroid Build Coastguard Worker                 break;
252*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
253*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
254*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
255*89c4ff92SAndroid Build Coastguard Worker                 break;
256*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
257*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
258*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
259*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
260*89c4ff92SAndroid Build Coastguard Worker                 break;
261*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
262*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
263*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
264*89c4ff92SAndroid Build Coastguard Worker                 break;
265*89c4ff92SAndroid Build Coastguard Worker             default:
266*89c4ff92SAndroid Build Coastguard Worker             {
267*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
268*89c4ff92SAndroid Build Coastguard Worker             }
269*89c4ff92SAndroid Build Coastguard Worker         }
270*89c4ff92SAndroid Build Coastguard Worker     }
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     arm_compute::Tensor m_Tensor;
273*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
274*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags m_ImportFlags;
275*89c4ff92SAndroid Build Coastguard Worker     bool m_Imported;
276*89c4ff92SAndroid Build Coastguard Worker     bool m_IsImportEnabled;
277*89c4ff92SAndroid Build Coastguard Worker     const uintptr_t m_TypeAlignment;
278*89c4ff92SAndroid Build Coastguard Worker };
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker class NeonSubTensorHandle : public IAclTensorHandle
281*89c4ff92SAndroid Build Coastguard Worker {
282*89c4ff92SAndroid Build Coastguard Worker public:
NeonSubTensorHandle(IAclTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)283*89c4ff92SAndroid Build Coastguard Worker     NeonSubTensorHandle(IAclTensorHandle* parent,
284*89c4ff92SAndroid Build Coastguard Worker                         const arm_compute::TensorShape& shape,
285*89c4ff92SAndroid Build Coastguard Worker                         const arm_compute::Coordinates& coords)
286*89c4ff92SAndroid Build Coastguard Worker      : m_Tensor(&parent->GetTensor(), shape, coords)
287*89c4ff92SAndroid Build Coastguard Worker     {
288*89c4ff92SAndroid Build Coastguard Worker         parentHandle = parent;
289*89c4ff92SAndroid Build Coastguard Worker     }
290*89c4ff92SAndroid Build Coastguard Worker 
GetTensor()291*89c4ff92SAndroid Build Coastguard Worker     arm_compute::ITensor& GetTensor() override { return m_Tensor; }
GetTensor() const292*89c4ff92SAndroid Build Coastguard Worker     arm_compute::ITensor const& GetTensor() const override { return m_Tensor; }
293*89c4ff92SAndroid Build Coastguard Worker 
Allocate()294*89c4ff92SAndroid Build Coastguard Worker     virtual void Allocate() override {}
Manage()295*89c4ff92SAndroid Build Coastguard Worker     virtual void Manage() override {}
296*89c4ff92SAndroid Build Coastguard Worker 
GetParent() const297*89c4ff92SAndroid Build Coastguard Worker     virtual ITensorHandle* GetParent() const override { return parentHandle; }
298*89c4ff92SAndroid Build Coastguard Worker 
GetDataType() const299*89c4ff92SAndroid Build Coastguard Worker     virtual arm_compute::DataType GetDataType() const override
300*89c4ff92SAndroid Build Coastguard Worker     {
301*89c4ff92SAndroid Build Coastguard Worker         return m_Tensor.info()->data_type();
302*89c4ff92SAndroid Build Coastguard Worker     }
303*89c4ff92SAndroid Build Coastguard Worker 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)304*89c4ff92SAndroid Build Coastguard Worker     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
305*89c4ff92SAndroid Build Coastguard Worker 
Map(bool) const306*89c4ff92SAndroid Build Coastguard Worker     virtual const void* Map(bool /* blocking = true */) const override
307*89c4ff92SAndroid Build Coastguard Worker     {
308*89c4ff92SAndroid Build Coastguard Worker         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
309*89c4ff92SAndroid Build Coastguard Worker     }
Unmap() const310*89c4ff92SAndroid Build Coastguard Worker     virtual void Unmap() const override {}
311*89c4ff92SAndroid Build Coastguard Worker 
GetStrides() const312*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetStrides() const override
313*89c4ff92SAndroid Build Coastguard Worker     {
314*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
315*89c4ff92SAndroid Build Coastguard Worker     }
316*89c4ff92SAndroid Build Coastguard Worker 
GetShape() const317*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetShape() const override
318*89c4ff92SAndroid Build Coastguard Worker     {
319*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
320*89c4ff92SAndroid Build Coastguard Worker     }
321*89c4ff92SAndroid Build Coastguard Worker 
322*89c4ff92SAndroid Build Coastguard Worker private:
323*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyOutTo(void * memory) const324*89c4ff92SAndroid Build Coastguard Worker     void CopyOutTo(void* memory) const override
325*89c4ff92SAndroid Build Coastguard Worker     {
326*89c4ff92SAndroid Build Coastguard Worker         switch (this->GetDataType())
327*89c4ff92SAndroid Build Coastguard Worker         {
328*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
329*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
330*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<float*>(memory));
331*89c4ff92SAndroid Build Coastguard Worker                 break;
332*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
333*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
334*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
335*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<uint8_t*>(memory));
336*89c4ff92SAndroid Build Coastguard Worker                 break;
337*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
338*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
339*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
340*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int8_t*>(memory));
341*89c4ff92SAndroid Build Coastguard Worker                 break;
342*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
343*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
344*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
345*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int16_t*>(memory));
346*89c4ff92SAndroid Build Coastguard Worker                 break;
347*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
348*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
349*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int32_t*>(memory));
350*89c4ff92SAndroid Build Coastguard Worker                 break;
351*89c4ff92SAndroid Build Coastguard Worker             default:
352*89c4ff92SAndroid Build Coastguard Worker             {
353*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
354*89c4ff92SAndroid Build Coastguard Worker             }
355*89c4ff92SAndroid Build Coastguard Worker         }
356*89c4ff92SAndroid Build Coastguard Worker     }
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyInFrom(const void * memory)359*89c4ff92SAndroid Build Coastguard Worker     void CopyInFrom(const void* memory) override
360*89c4ff92SAndroid Build Coastguard Worker     {
361*89c4ff92SAndroid Build Coastguard Worker         switch (this->GetDataType())
362*89c4ff92SAndroid Build Coastguard Worker         {
363*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
364*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
365*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
366*89c4ff92SAndroid Build Coastguard Worker                 break;
367*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
368*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
369*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
370*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
371*89c4ff92SAndroid Build Coastguard Worker                 break;
372*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
373*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
374*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
375*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
376*89c4ff92SAndroid Build Coastguard Worker                 break;
377*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
378*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
379*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
380*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
381*89c4ff92SAndroid Build Coastguard Worker                 break;
382*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
383*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
384*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
385*89c4ff92SAndroid Build Coastguard Worker                 break;
386*89c4ff92SAndroid Build Coastguard Worker             default:
387*89c4ff92SAndroid Build Coastguard Worker             {
388*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
389*89c4ff92SAndroid Build Coastguard Worker             }
390*89c4ff92SAndroid Build Coastguard Worker         }
391*89c4ff92SAndroid Build Coastguard Worker     }
392*89c4ff92SAndroid Build Coastguard Worker 
393*89c4ff92SAndroid Build Coastguard Worker     arm_compute::SubTensor m_Tensor;
394*89c4ff92SAndroid Build Coastguard Worker     ITensorHandle* parentHandle = nullptr;
395*89c4ff92SAndroid Build Coastguard Worker };
396*89c4ff92SAndroid Build Coastguard Worker 
397*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
398