xref: /aosp_15_r20/external/armnn/src/backends/cl/ClTensorHandle.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorHandle.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLTensor.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLSubTensor.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/IMemoryGroup.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/MemoryGroup.h>
18*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/TensorShape.h>
19*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Coordinates.h>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/IClTensorHandle.hpp>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker namespace armnn
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker class ClTensorHandle : public IClTensorHandle
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker public:
ClTensorHandle(const TensorInfo & tensorInfo)29*89c4ff92SAndroid Build Coastguard Worker     ClTensorHandle(const TensorInfo& tensorInfo)
30*89c4ff92SAndroid Build Coastguard Worker                      : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)),
31*89c4ff92SAndroid Build Coastguard Worker                        m_Imported(false),
32*89c4ff92SAndroid Build Coastguard Worker                        m_IsImportEnabled(false)
33*89c4ff92SAndroid Build Coastguard Worker     {
34*89c4ff92SAndroid Build Coastguard Worker         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo);
35*89c4ff92SAndroid Build Coastguard Worker     }
36*89c4ff92SAndroid Build Coastguard Worker 
ClTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Undefined))37*89c4ff92SAndroid Build Coastguard Worker     ClTensorHandle(const TensorInfo& tensorInfo,
38*89c4ff92SAndroid Build Coastguard Worker                    DataLayout dataLayout,
39*89c4ff92SAndroid Build Coastguard Worker                    MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined))
40*89c4ff92SAndroid Build Coastguard Worker                    : m_ImportFlags(importFlags),
41*89c4ff92SAndroid Build Coastguard Worker                      m_Imported(false),
42*89c4ff92SAndroid Build Coastguard Worker                      m_IsImportEnabled(false)
43*89c4ff92SAndroid Build Coastguard Worker     {
44*89c4ff92SAndroid Build Coastguard Worker         armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout);
45*89c4ff92SAndroid Build Coastguard Worker     }
46*89c4ff92SAndroid Build Coastguard Worker 
GetTensor()47*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLTensor& GetTensor() override { return m_Tensor; }
GetTensor() const48*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; }
Allocate()49*89c4ff92SAndroid Build Coastguard Worker     virtual void Allocate() override
50*89c4ff92SAndroid Build Coastguard Worker     {
51*89c4ff92SAndroid Build Coastguard Worker         // If we have enabled Importing, don't allocate the tensor
52*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled)
53*89c4ff92SAndroid Build Coastguard Worker         {
54*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing");
55*89c4ff92SAndroid Build Coastguard Worker         }
56*89c4ff92SAndroid Build Coastguard Worker         else
57*89c4ff92SAndroid Build Coastguard Worker         {
58*89c4ff92SAndroid Build Coastguard Worker             armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor);
59*89c4ff92SAndroid Build Coastguard Worker         }
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker 
Manage()63*89c4ff92SAndroid Build Coastguard Worker     virtual void Manage() override
64*89c4ff92SAndroid Build Coastguard Worker     {
65*89c4ff92SAndroid Build Coastguard Worker         // If we have enabled Importing, don't manage the tensor
66*89c4ff92SAndroid Build Coastguard Worker         if (m_IsImportEnabled)
67*89c4ff92SAndroid Build Coastguard Worker         {
68*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing");
69*89c4ff92SAndroid Build Coastguard Worker         }
70*89c4ff92SAndroid Build Coastguard Worker         else
71*89c4ff92SAndroid Build Coastguard Worker         {
72*89c4ff92SAndroid Build Coastguard Worker             assert(m_MemoryGroup != nullptr);
73*89c4ff92SAndroid Build Coastguard Worker             m_MemoryGroup->manage(&m_Tensor);
74*89c4ff92SAndroid Build Coastguard Worker         }
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker 
Map(bool blocking=true) const77*89c4ff92SAndroid Build Coastguard Worker     virtual const void* Map(bool blocking = true) const override
78*89c4ff92SAndroid Build Coastguard Worker     {
79*89c4ff92SAndroid Build Coastguard Worker         const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking);
80*89c4ff92SAndroid Build Coastguard Worker         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
81*89c4ff92SAndroid Build Coastguard Worker     }
82*89c4ff92SAndroid Build Coastguard Worker 
Unmap() const83*89c4ff92SAndroid Build Coastguard Worker     virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); }
84*89c4ff92SAndroid Build Coastguard Worker 
GetParent() const85*89c4ff92SAndroid Build Coastguard Worker     virtual ITensorHandle* GetParent() const override { return nullptr; }
86*89c4ff92SAndroid Build Coastguard Worker 
GetDataType() const87*89c4ff92SAndroid Build Coastguard Worker     virtual arm_compute::DataType GetDataType() const override
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         return m_Tensor.info()->data_type();
90*89c4ff92SAndroid Build Coastguard Worker     }
91*89c4ff92SAndroid Build Coastguard Worker 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)92*89c4ff92SAndroid Build Coastguard Worker     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override
93*89c4ff92SAndroid Build Coastguard Worker     {
94*89c4ff92SAndroid Build Coastguard Worker         m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup);
95*89c4ff92SAndroid Build Coastguard Worker     }
96*89c4ff92SAndroid Build Coastguard Worker 
GetStrides() const97*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetStrides() const override
98*89c4ff92SAndroid Build Coastguard Worker     {
99*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
100*89c4ff92SAndroid Build Coastguard Worker     }
101*89c4ff92SAndroid Build Coastguard Worker 
GetShape() const102*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetShape() const override
103*89c4ff92SAndroid Build Coastguard Worker     {
104*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
105*89c4ff92SAndroid Build Coastguard Worker     }
106*89c4ff92SAndroid Build Coastguard Worker 
SetImportFlags(MemorySourceFlags importFlags)107*89c4ff92SAndroid Build Coastguard Worker     void SetImportFlags(MemorySourceFlags importFlags)
108*89c4ff92SAndroid Build Coastguard Worker     {
109*89c4ff92SAndroid Build Coastguard Worker         m_ImportFlags = importFlags;
110*89c4ff92SAndroid Build Coastguard Worker     }
111*89c4ff92SAndroid Build Coastguard Worker 
GetImportFlags() const112*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags GetImportFlags() const override
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         return m_ImportFlags;
115*89c4ff92SAndroid Build Coastguard Worker     }
116*89c4ff92SAndroid Build Coastguard Worker 
SetImportEnabledFlag(bool importEnabledFlag)117*89c4ff92SAndroid Build Coastguard Worker     void SetImportEnabledFlag(bool importEnabledFlag)
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         m_IsImportEnabled = importEnabledFlag;
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker 
Import(void * memory,MemorySource source)122*89c4ff92SAndroid Build Coastguard Worker     virtual bool Import(void* memory, MemorySource source) override
123*89c4ff92SAndroid Build Coastguard Worker     {
124*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(memory);
125*89c4ff92SAndroid Build Coastguard Worker         if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
126*89c4ff92SAndroid Build Coastguard Worker         {
127*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ClTensorHandle::Incorrect import flag");
128*89c4ff92SAndroid Build Coastguard Worker         }
129*89c4ff92SAndroid Build Coastguard Worker         m_Imported = false;
130*89c4ff92SAndroid Build Coastguard Worker         return false;
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
CanBeImported(void * memory,MemorySource source)133*89c4ff92SAndroid Build Coastguard Worker     virtual bool CanBeImported(void* memory, MemorySource source) override
134*89c4ff92SAndroid Build Coastguard Worker     {
135*89c4ff92SAndroid Build Coastguard Worker         // This TensorHandle can never import.
136*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(memory, source);
137*89c4ff92SAndroid Build Coastguard Worker         return false;
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker private:
141*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyOutTo(void * memory) const142*89c4ff92SAndroid Build Coastguard Worker     void CopyOutTo(void* memory) const override
143*89c4ff92SAndroid Build Coastguard Worker     {
144*89c4ff92SAndroid Build Coastguard Worker         const_cast<armnn::ClTensorHandle*>(this)->Map(true);
145*89c4ff92SAndroid Build Coastguard Worker         switch(this->GetDataType())
146*89c4ff92SAndroid Build Coastguard Worker         {
147*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
148*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
149*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<float*>(memory));
150*89c4ff92SAndroid Build Coastguard Worker                 break;
151*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
152*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
153*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
154*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<uint8_t*>(memory));
155*89c4ff92SAndroid Build Coastguard Worker                 break;
156*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
157*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
158*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
159*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
160*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int8_t*>(memory));
161*89c4ff92SAndroid Build Coastguard Worker                 break;
162*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
163*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
164*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<armnn::Half*>(memory));
165*89c4ff92SAndroid Build Coastguard Worker                 break;
166*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
167*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
168*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
169*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int16_t*>(memory));
170*89c4ff92SAndroid Build Coastguard Worker                 break;
171*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
172*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
173*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int32_t*>(memory));
174*89c4ff92SAndroid Build Coastguard Worker                 break;
175*89c4ff92SAndroid Build Coastguard Worker             default:
176*89c4ff92SAndroid Build Coastguard Worker             {
177*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
178*89c4ff92SAndroid Build Coastguard Worker             }
179*89c4ff92SAndroid Build Coastguard Worker         }
180*89c4ff92SAndroid Build Coastguard Worker         const_cast<armnn::ClTensorHandle*>(this)->Unmap();
181*89c4ff92SAndroid Build Coastguard Worker     }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyInFrom(const void * memory)184*89c4ff92SAndroid Build Coastguard Worker     void CopyInFrom(const void* memory) override
185*89c4ff92SAndroid Build Coastguard Worker     {
186*89c4ff92SAndroid Build Coastguard Worker         this->Map(true);
187*89c4ff92SAndroid Build Coastguard Worker         switch(this->GetDataType())
188*89c4ff92SAndroid Build Coastguard Worker         {
189*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
190*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
191*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
192*89c4ff92SAndroid Build Coastguard Worker                 break;
193*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
194*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
195*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
196*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
197*89c4ff92SAndroid Build Coastguard Worker                 break;
198*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
199*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
200*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
201*89c4ff92SAndroid Build Coastguard Worker                 break;
202*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
203*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
204*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
205*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
206*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
207*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
208*89c4ff92SAndroid Build Coastguard Worker                 break;
209*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
210*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
211*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
212*89c4ff92SAndroid Build Coastguard Worker                 break;
213*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
214*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
215*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
216*89c4ff92SAndroid Build Coastguard Worker                 break;
217*89c4ff92SAndroid Build Coastguard Worker             default:
218*89c4ff92SAndroid Build Coastguard Worker             {
219*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
220*89c4ff92SAndroid Build Coastguard Worker             }
221*89c4ff92SAndroid Build Coastguard Worker         }
222*89c4ff92SAndroid Build Coastguard Worker         this->Unmap();
223*89c4ff92SAndroid Build Coastguard Worker     }
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLTensor m_Tensor;
226*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup;
227*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags m_ImportFlags;
228*89c4ff92SAndroid Build Coastguard Worker     bool m_Imported;
229*89c4ff92SAndroid Build Coastguard Worker     bool m_IsImportEnabled;
230*89c4ff92SAndroid Build Coastguard Worker };
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker class ClSubTensorHandle : public IClTensorHandle
233*89c4ff92SAndroid Build Coastguard Worker {
234*89c4ff92SAndroid Build Coastguard Worker public:
ClSubTensorHandle(IClTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)235*89c4ff92SAndroid Build Coastguard Worker     ClSubTensorHandle(IClTensorHandle* parent,
236*89c4ff92SAndroid Build Coastguard Worker                       const arm_compute::TensorShape& shape,
237*89c4ff92SAndroid Build Coastguard Worker                       const arm_compute::Coordinates& coords)
238*89c4ff92SAndroid Build Coastguard Worker     : m_Tensor(&parent->GetTensor(), shape, coords)
239*89c4ff92SAndroid Build Coastguard Worker     {
240*89c4ff92SAndroid Build Coastguard Worker         parentHandle = parent;
241*89c4ff92SAndroid Build Coastguard Worker     }
242*89c4ff92SAndroid Build Coastguard Worker 
GetTensor()243*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; }
GetTensor() const244*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; }
245*89c4ff92SAndroid Build Coastguard Worker 
Allocate()246*89c4ff92SAndroid Build Coastguard Worker     virtual void Allocate() override {}
Manage()247*89c4ff92SAndroid Build Coastguard Worker     virtual void Manage() override {}
248*89c4ff92SAndroid Build Coastguard Worker 
Map(bool blocking=true) const249*89c4ff92SAndroid Build Coastguard Worker     virtual const void* Map(bool blocking = true) const override
250*89c4ff92SAndroid Build Coastguard Worker     {
251*89c4ff92SAndroid Build Coastguard Worker         const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking);
252*89c4ff92SAndroid Build Coastguard Worker         return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes());
253*89c4ff92SAndroid Build Coastguard Worker     }
Unmap() const254*89c4ff92SAndroid Build Coastguard Worker     virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); }
255*89c4ff92SAndroid Build Coastguard Worker 
GetParent() const256*89c4ff92SAndroid Build Coastguard Worker     virtual ITensorHandle* GetParent() const override { return parentHandle; }
257*89c4ff92SAndroid Build Coastguard Worker 
GetDataType() const258*89c4ff92SAndroid Build Coastguard Worker     virtual arm_compute::DataType GetDataType() const override
259*89c4ff92SAndroid Build Coastguard Worker     {
260*89c4ff92SAndroid Build Coastguard Worker         return m_Tensor.info()->data_type();
261*89c4ff92SAndroid Build Coastguard Worker     }
262*89c4ff92SAndroid Build Coastguard Worker 
SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)263*89c4ff92SAndroid Build Coastguard Worker     virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {}
264*89c4ff92SAndroid Build Coastguard Worker 
GetStrides() const265*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetStrides() const override
266*89c4ff92SAndroid Build Coastguard Worker     {
267*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes());
268*89c4ff92SAndroid Build Coastguard Worker     }
269*89c4ff92SAndroid Build Coastguard Worker 
GetShape() const270*89c4ff92SAndroid Build Coastguard Worker     TensorShape GetShape() const override
271*89c4ff92SAndroid Build Coastguard Worker     {
272*89c4ff92SAndroid Build Coastguard Worker         return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape());
273*89c4ff92SAndroid Build Coastguard Worker     }
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker private:
276*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyOutTo(void * memory) const277*89c4ff92SAndroid Build Coastguard Worker     void CopyOutTo(void* memory) const override
278*89c4ff92SAndroid Build Coastguard Worker     {
279*89c4ff92SAndroid Build Coastguard Worker         const_cast<ClSubTensorHandle*>(this)->Map(true);
280*89c4ff92SAndroid Build Coastguard Worker         switch(this->GetDataType())
281*89c4ff92SAndroid Build Coastguard Worker         {
282*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
283*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
284*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<float*>(memory));
285*89c4ff92SAndroid Build Coastguard Worker                 break;
286*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
287*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
288*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
289*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<uint8_t*>(memory));
290*89c4ff92SAndroid Build Coastguard Worker                 break;
291*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
292*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
293*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<armnn::Half*>(memory));
294*89c4ff92SAndroid Build Coastguard Worker                 break;
295*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
296*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
297*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
298*89c4ff92SAndroid Build Coastguard Worker             armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
299*89c4ff92SAndroid Build Coastguard Worker                                                              static_cast<int8_t*>(memory));
300*89c4ff92SAndroid Build Coastguard Worker                 break;
301*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
302*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
303*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
304*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int16_t*>(memory));
305*89c4ff92SAndroid Build Coastguard Worker                 break;
306*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
307*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(),
308*89c4ff92SAndroid Build Coastguard Worker                                                                  static_cast<int32_t*>(memory));
309*89c4ff92SAndroid Build Coastguard Worker                 break;
310*89c4ff92SAndroid Build Coastguard Worker             default:
311*89c4ff92SAndroid Build Coastguard Worker             {
312*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
313*89c4ff92SAndroid Build Coastguard Worker             }
314*89c4ff92SAndroid Build Coastguard Worker         }
315*89c4ff92SAndroid Build Coastguard Worker         const_cast<ClSubTensorHandle*>(this)->Unmap();
316*89c4ff92SAndroid Build Coastguard Worker     }
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker     // Only used for testing
CopyInFrom(const void * memory)319*89c4ff92SAndroid Build Coastguard Worker     void CopyInFrom(const void* memory) override
320*89c4ff92SAndroid Build Coastguard Worker     {
321*89c4ff92SAndroid Build Coastguard Worker         this->Map(true);
322*89c4ff92SAndroid Build Coastguard Worker         switch(this->GetDataType())
323*89c4ff92SAndroid Build Coastguard Worker         {
324*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F32:
325*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory),
326*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
327*89c4ff92SAndroid Build Coastguard Worker                 break;
328*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::U8:
329*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8:
330*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory),
331*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
332*89c4ff92SAndroid Build Coastguard Worker                 break;
333*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::F16:
334*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory),
335*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
336*89c4ff92SAndroid Build Coastguard Worker                 break;
337*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8:
338*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
339*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QASYMM8_SIGNED:
340*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory),
341*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
342*89c4ff92SAndroid Build Coastguard Worker                 break;
343*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S16:
344*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::QSYMM16:
345*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory),
346*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
347*89c4ff92SAndroid Build Coastguard Worker                 break;
348*89c4ff92SAndroid Build Coastguard Worker             case arm_compute::DataType::S32:
349*89c4ff92SAndroid Build Coastguard Worker                 armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory),
350*89c4ff92SAndroid Build Coastguard Worker                                                                  this->GetTensor());
351*89c4ff92SAndroid Build Coastguard Worker                 break;
352*89c4ff92SAndroid Build Coastguard Worker             default:
353*89c4ff92SAndroid Build Coastguard Worker             {
354*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::UnimplementedException();
355*89c4ff92SAndroid Build Coastguard Worker             }
356*89c4ff92SAndroid Build Coastguard Worker         }
357*89c4ff92SAndroid Build Coastguard Worker         this->Unmap();
358*89c4ff92SAndroid Build Coastguard Worker     }
359*89c4ff92SAndroid Build Coastguard Worker 
360*89c4ff92SAndroid Build Coastguard Worker     mutable arm_compute::CLSubTensor m_Tensor;
361*89c4ff92SAndroid Build Coastguard Worker     ITensorHandle* parentHandle = nullptr;
362*89c4ff92SAndroid Build Coastguard Worker };
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
365