1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorHandle.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLTensor.h> 15*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLSubTensor.h> 16*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/IMemoryGroup.h> 17*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/MemoryGroup.h> 18*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/TensorShape.h> 19*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Coordinates.h> 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/IClTensorHandle.hpp> 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker namespace armnn 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker class ClTensorHandle : public IClTensorHandle 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker public: ClTensorHandle(const TensorInfo & tensorInfo)29*89c4ff92SAndroid Build Coastguard Worker ClTensorHandle(const TensorInfo& tensorInfo) 30*89c4ff92SAndroid Build Coastguard Worker : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined)), 31*89c4ff92SAndroid Build Coastguard Worker m_Imported(false), 32*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled(false) 33*89c4ff92SAndroid Build Coastguard Worker { 34*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker ClTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Undefined))37*89c4ff92SAndroid Build Coastguard Worker ClTensorHandle(const TensorInfo& tensorInfo, 38*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout, 39*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Undefined)) 40*89c4ff92SAndroid Build Coastguard Worker : m_ImportFlags(importFlags), 41*89c4ff92SAndroid Build Coastguard Worker m_Imported(false), 42*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled(false) 43*89c4ff92SAndroid Build Coastguard Worker { 44*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker GetTensor()47*89c4ff92SAndroid Build Coastguard Worker arm_compute::CLTensor& GetTensor() override { return m_Tensor; } GetTensor() const48*89c4ff92SAndroid Build Coastguard Worker arm_compute::CLTensor const& GetTensor() const override { return m_Tensor; } Allocate()49*89c4ff92SAndroid Build Coastguard Worker virtual void Allocate() override 50*89c4ff92SAndroid Build Coastguard Worker { 51*89c4ff92SAndroid Build Coastguard Worker // If we have enabled Importing, don't allocate the tensor 52*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled) 53*89c4ff92SAndroid Build Coastguard Worker { 54*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ClTensorHandle::Attempting to allocate memory when importing"); 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker else 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); 59*89c4ff92SAndroid Build Coastguard Worker } 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker } 62*89c4ff92SAndroid Build Coastguard Worker Manage()63*89c4ff92SAndroid Build Coastguard Worker virtual void Manage() override 64*89c4ff92SAndroid Build Coastguard Worker { 65*89c4ff92SAndroid Build Coastguard Worker // If we have enabled Importing, don't manage the tensor 66*89c4ff92SAndroid Build Coastguard Worker if (m_IsImportEnabled) 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ClTensorHandle::Attempting to manage memory when importing"); 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker else 71*89c4ff92SAndroid Build Coastguard Worker { 72*89c4ff92SAndroid Build Coastguard Worker assert(m_MemoryGroup != nullptr); 73*89c4ff92SAndroid Build Coastguard Worker m_MemoryGroup->manage(&m_Tensor); 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker } 76*89c4ff92SAndroid Build Coastguard Worker Map(bool blocking=true) const77*89c4ff92SAndroid Build Coastguard Worker virtual const void* Map(bool blocking = true) const override 78*89c4ff92SAndroid Build Coastguard Worker { 79*89c4ff92SAndroid Build Coastguard Worker const_cast<arm_compute::CLTensor*>(&m_Tensor)->map(blocking); 80*89c4ff92SAndroid Build Coastguard Worker return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker Unmap() const83*89c4ff92SAndroid Build Coastguard Worker virtual void Unmap() const override { const_cast<arm_compute::CLTensor*>(&m_Tensor)->unmap(); } 84*89c4ff92SAndroid Build Coastguard Worker GetParent() const85*89c4ff92SAndroid Build Coastguard Worker virtual ITensorHandle* GetParent() const override { return nullptr; } 86*89c4ff92SAndroid Build Coastguard Worker GetDataType() const87*89c4ff92SAndroid Build Coastguard Worker virtual arm_compute::DataType GetDataType() const override 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker return m_Tensor.info()->data_type(); 90*89c4ff92SAndroid Build Coastguard Worker } 91*89c4ff92SAndroid Build Coastguard Worker SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)92*89c4ff92SAndroid Build Coastguard Worker virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 93*89c4ff92SAndroid Build Coastguard Worker { 94*89c4ff92SAndroid Build Coastguard Worker m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker GetStrides() const97*89c4ff92SAndroid Build Coastguard Worker TensorShape GetStrides() const override 98*89c4ff92SAndroid Build Coastguard Worker { 99*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 100*89c4ff92SAndroid Build Coastguard Worker } 101*89c4ff92SAndroid Build Coastguard Worker GetShape() const102*89c4ff92SAndroid Build Coastguard Worker TensorShape GetShape() const override 103*89c4ff92SAndroid Build Coastguard Worker { 104*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 105*89c4ff92SAndroid Build Coastguard Worker } 106*89c4ff92SAndroid Build Coastguard Worker SetImportFlags(MemorySourceFlags importFlags)107*89c4ff92SAndroid Build Coastguard Worker void SetImportFlags(MemorySourceFlags importFlags) 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker m_ImportFlags = importFlags; 110*89c4ff92SAndroid Build Coastguard Worker } 111*89c4ff92SAndroid Build Coastguard Worker GetImportFlags() const112*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags GetImportFlags() const override 113*89c4ff92SAndroid Build Coastguard Worker { 114*89c4ff92SAndroid Build Coastguard Worker return m_ImportFlags; 115*89c4ff92SAndroid Build Coastguard Worker } 116*89c4ff92SAndroid Build Coastguard Worker SetImportEnabledFlag(bool importEnabledFlag)117*89c4ff92SAndroid Build Coastguard Worker void SetImportEnabledFlag(bool importEnabledFlag) 118*89c4ff92SAndroid Build Coastguard Worker { 119*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled = importEnabledFlag; 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker Import(void * memory,MemorySource source)122*89c4ff92SAndroid Build Coastguard Worker virtual bool Import(void* memory, MemorySource source) override 123*89c4ff92SAndroid Build Coastguard Worker { 124*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(memory); 125*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) 126*89c4ff92SAndroid Build Coastguard Worker { 127*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ClTensorHandle::Incorrect import flag"); 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker m_Imported = false; 130*89c4ff92SAndroid Build Coastguard Worker return false; 131*89c4ff92SAndroid Build Coastguard Worker } 132*89c4ff92SAndroid Build Coastguard Worker CanBeImported(void * memory,MemorySource source)133*89c4ff92SAndroid Build Coastguard Worker virtual bool CanBeImported(void* memory, MemorySource source) override 134*89c4ff92SAndroid Build Coastguard Worker { 135*89c4ff92SAndroid Build Coastguard Worker // This TensorHandle can never import. 136*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(memory, source); 137*89c4ff92SAndroid Build Coastguard Worker return false; 138*89c4ff92SAndroid Build Coastguard Worker } 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker private: 141*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyOutTo(void * memory) const142*89c4ff92SAndroid Build Coastguard Worker void CopyOutTo(void* memory) const override 143*89c4ff92SAndroid Build Coastguard Worker { 144*89c4ff92SAndroid Build Coastguard Worker const_cast<armnn::ClTensorHandle*>(this)->Map(true); 145*89c4ff92SAndroid Build Coastguard Worker switch(this->GetDataType()) 146*89c4ff92SAndroid Build Coastguard Worker { 147*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 148*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 149*89c4ff92SAndroid Build Coastguard Worker static_cast<float*>(memory)); 150*89c4ff92SAndroid Build Coastguard Worker break; 151*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 152*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 153*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 154*89c4ff92SAndroid Build Coastguard Worker static_cast<uint8_t*>(memory)); 155*89c4ff92SAndroid Build Coastguard Worker break; 156*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 157*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL: 158*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 159*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 160*89c4ff92SAndroid Build Coastguard Worker static_cast<int8_t*>(memory)); 161*89c4ff92SAndroid Build Coastguard Worker break; 162*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 163*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 164*89c4ff92SAndroid Build Coastguard Worker static_cast<armnn::Half*>(memory)); 165*89c4ff92SAndroid Build Coastguard Worker break; 166*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 167*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 168*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 169*89c4ff92SAndroid Build Coastguard Worker static_cast<int16_t*>(memory)); 170*89c4ff92SAndroid Build Coastguard Worker break; 171*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 172*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 173*89c4ff92SAndroid Build Coastguard Worker static_cast<int32_t*>(memory)); 174*89c4ff92SAndroid Build Coastguard Worker break; 175*89c4ff92SAndroid Build Coastguard Worker default: 176*89c4ff92SAndroid Build Coastguard Worker { 177*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker } 180*89c4ff92SAndroid Build Coastguard Worker const_cast<armnn::ClTensorHandle*>(this)->Unmap(); 181*89c4ff92SAndroid Build Coastguard Worker } 182*89c4ff92SAndroid Build Coastguard Worker 183*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyInFrom(const void * memory)184*89c4ff92SAndroid Build Coastguard Worker void CopyInFrom(const void* memory) override 185*89c4ff92SAndroid Build Coastguard Worker { 186*89c4ff92SAndroid Build Coastguard Worker this->Map(true); 187*89c4ff92SAndroid Build Coastguard Worker switch(this->GetDataType()) 188*89c4ff92SAndroid Build Coastguard Worker { 189*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 190*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 191*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 192*89c4ff92SAndroid Build Coastguard Worker break; 193*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 194*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 195*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 196*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 197*89c4ff92SAndroid Build Coastguard Worker break; 198*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 199*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 200*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 201*89c4ff92SAndroid Build Coastguard Worker break; 202*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 203*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 204*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL: 205*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 206*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 207*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 208*89c4ff92SAndroid Build Coastguard Worker break; 209*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 210*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 211*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 212*89c4ff92SAndroid Build Coastguard Worker break; 213*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 214*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 215*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 216*89c4ff92SAndroid Build Coastguard Worker break; 217*89c4ff92SAndroid Build Coastguard Worker default: 218*89c4ff92SAndroid Build Coastguard Worker { 219*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 220*89c4ff92SAndroid Build Coastguard Worker } 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker this->Unmap(); 223*89c4ff92SAndroid Build Coastguard Worker } 224*89c4ff92SAndroid Build Coastguard Worker 225*89c4ff92SAndroid Build Coastguard Worker arm_compute::CLTensor m_Tensor; 226*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 227*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags m_ImportFlags; 228*89c4ff92SAndroid Build Coastguard Worker bool m_Imported; 229*89c4ff92SAndroid Build Coastguard Worker bool m_IsImportEnabled; 230*89c4ff92SAndroid Build Coastguard Worker }; 231*89c4ff92SAndroid Build Coastguard Worker 232*89c4ff92SAndroid Build Coastguard Worker class ClSubTensorHandle : public IClTensorHandle 233*89c4ff92SAndroid Build Coastguard Worker { 234*89c4ff92SAndroid Build Coastguard Worker public: ClSubTensorHandle(IClTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)235*89c4ff92SAndroid Build Coastguard Worker ClSubTensorHandle(IClTensorHandle* parent, 236*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& shape, 237*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Coordinates& coords) 238*89c4ff92SAndroid Build Coastguard Worker : m_Tensor(&parent->GetTensor(), shape, coords) 239*89c4ff92SAndroid Build Coastguard Worker { 240*89c4ff92SAndroid Build Coastguard Worker parentHandle = parent; 241*89c4ff92SAndroid Build Coastguard Worker } 242*89c4ff92SAndroid Build Coastguard Worker GetTensor()243*89c4ff92SAndroid Build Coastguard Worker arm_compute::CLSubTensor& GetTensor() override { return m_Tensor; } GetTensor() const244*89c4ff92SAndroid Build Coastguard Worker arm_compute::CLSubTensor const& GetTensor() const override { return m_Tensor; } 245*89c4ff92SAndroid Build Coastguard Worker Allocate()246*89c4ff92SAndroid Build Coastguard Worker virtual void Allocate() override {} Manage()247*89c4ff92SAndroid Build Coastguard Worker virtual void Manage() override {} 248*89c4ff92SAndroid Build Coastguard Worker Map(bool blocking=true) const249*89c4ff92SAndroid Build Coastguard Worker virtual const void* Map(bool blocking = true) const override 250*89c4ff92SAndroid Build Coastguard Worker { 251*89c4ff92SAndroid Build Coastguard Worker const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->map(blocking); 252*89c4ff92SAndroid Build Coastguard Worker return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 253*89c4ff92SAndroid Build Coastguard Worker } Unmap() const254*89c4ff92SAndroid Build Coastguard Worker virtual void Unmap() const override { const_cast<arm_compute::CLSubTensor*>(&m_Tensor)->unmap(); } 255*89c4ff92SAndroid Build Coastguard Worker GetParent() const256*89c4ff92SAndroid Build Coastguard Worker virtual ITensorHandle* GetParent() const override { return parentHandle; } 257*89c4ff92SAndroid Build Coastguard Worker GetDataType() const258*89c4ff92SAndroid Build Coastguard Worker virtual arm_compute::DataType GetDataType() const override 259*89c4ff92SAndroid Build Coastguard Worker { 260*89c4ff92SAndroid Build Coastguard Worker return m_Tensor.info()->data_type(); 261*89c4ff92SAndroid Build Coastguard Worker } 262*89c4ff92SAndroid Build Coastguard Worker SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)263*89c4ff92SAndroid Build Coastguard Worker virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 264*89c4ff92SAndroid Build Coastguard Worker GetStrides() const265*89c4ff92SAndroid Build Coastguard Worker TensorShape GetStrides() const override 266*89c4ff92SAndroid Build Coastguard Worker { 267*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 268*89c4ff92SAndroid Build Coastguard Worker } 269*89c4ff92SAndroid Build Coastguard Worker GetShape() const270*89c4ff92SAndroid Build Coastguard Worker TensorShape GetShape() const override 271*89c4ff92SAndroid Build Coastguard Worker { 272*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 273*89c4ff92SAndroid Build Coastguard Worker } 274*89c4ff92SAndroid Build Coastguard Worker 275*89c4ff92SAndroid Build Coastguard Worker private: 276*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyOutTo(void * memory) const277*89c4ff92SAndroid Build Coastguard Worker void CopyOutTo(void* memory) const override 278*89c4ff92SAndroid Build Coastguard Worker { 279*89c4ff92SAndroid Build Coastguard Worker const_cast<ClSubTensorHandle*>(this)->Map(true); 280*89c4ff92SAndroid Build Coastguard Worker switch(this->GetDataType()) 281*89c4ff92SAndroid Build Coastguard Worker { 282*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 283*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 284*89c4ff92SAndroid Build Coastguard Worker static_cast<float*>(memory)); 285*89c4ff92SAndroid Build Coastguard Worker break; 286*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 287*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 288*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 289*89c4ff92SAndroid Build Coastguard Worker static_cast<uint8_t*>(memory)); 290*89c4ff92SAndroid Build Coastguard Worker break; 291*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 292*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 293*89c4ff92SAndroid Build Coastguard Worker static_cast<armnn::Half*>(memory)); 294*89c4ff92SAndroid Build Coastguard Worker break; 295*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 296*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL: 297*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 298*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 299*89c4ff92SAndroid Build Coastguard Worker static_cast<int8_t*>(memory)); 300*89c4ff92SAndroid Build Coastguard Worker break; 301*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 302*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 303*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 304*89c4ff92SAndroid Build Coastguard Worker static_cast<int16_t*>(memory)); 305*89c4ff92SAndroid Build Coastguard Worker break; 306*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 307*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 308*89c4ff92SAndroid Build Coastguard Worker static_cast<int32_t*>(memory)); 309*89c4ff92SAndroid Build Coastguard Worker break; 310*89c4ff92SAndroid Build Coastguard Worker default: 311*89c4ff92SAndroid Build Coastguard Worker { 312*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 313*89c4ff92SAndroid Build Coastguard Worker } 314*89c4ff92SAndroid Build Coastguard Worker } 315*89c4ff92SAndroid Build Coastguard Worker const_cast<ClSubTensorHandle*>(this)->Unmap(); 316*89c4ff92SAndroid Build Coastguard Worker } 317*89c4ff92SAndroid Build Coastguard Worker 318*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyInFrom(const void * memory)319*89c4ff92SAndroid Build Coastguard Worker void CopyInFrom(const void* memory) override 320*89c4ff92SAndroid Build Coastguard Worker { 321*89c4ff92SAndroid Build Coastguard Worker this->Map(true); 322*89c4ff92SAndroid Build Coastguard Worker switch(this->GetDataType()) 323*89c4ff92SAndroid Build Coastguard Worker { 324*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 325*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 326*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 327*89c4ff92SAndroid Build Coastguard Worker break; 328*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 329*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 330*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 331*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 332*89c4ff92SAndroid Build Coastguard Worker break; 333*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 334*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 335*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 336*89c4ff92SAndroid Build Coastguard Worker break; 337*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 338*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL: 339*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 340*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 341*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 342*89c4ff92SAndroid Build Coastguard Worker break; 343*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 344*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 345*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 346*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 347*89c4ff92SAndroid Build Coastguard Worker break; 348*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 349*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 350*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 351*89c4ff92SAndroid Build Coastguard Worker break; 352*89c4ff92SAndroid Build Coastguard Worker default: 353*89c4ff92SAndroid Build Coastguard Worker { 354*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 355*89c4ff92SAndroid Build Coastguard Worker } 356*89c4ff92SAndroid Build Coastguard Worker } 357*89c4ff92SAndroid Build Coastguard Worker this->Unmap(); 358*89c4ff92SAndroid Build Coastguard Worker } 359*89c4ff92SAndroid Build Coastguard Worker 360*89c4ff92SAndroid Build Coastguard Worker mutable arm_compute::CLSubTensor m_Tensor; 361*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* parentHandle = nullptr; 362*89c4ff92SAndroid Build Coastguard Worker }; 363*89c4ff92SAndroid Build Coastguard Worker 364*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 365