1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include <BFloat16.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <Half.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorHandle.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeTensorUtils.hpp> 14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp> 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/MemoryGroup.h> 17*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/IMemoryGroup.h> 18*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/Tensor.h> 19*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/SubTensor.h> 20*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/TensorShape.h> 21*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Coordinates.h> 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker namespace armnn 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker 26*89c4ff92SAndroid Build Coastguard Worker class NeonTensorHandle : public IAclTensorHandle 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker public: NeonTensorHandle(const TensorInfo & tensorInfo)29*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandle(const TensorInfo& tensorInfo) 30*89c4ff92SAndroid Build Coastguard Worker : m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)), 31*89c4ff92SAndroid Build Coastguard Worker m_Imported(false), 32*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled(false), 33*89c4ff92SAndroid Build Coastguard Worker m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) 34*89c4ff92SAndroid Build Coastguard Worker { 35*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo); 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,MemorySourceFlags importFlags=static_cast<MemorySourceFlags> (MemorySource::Malloc))38*89c4ff92SAndroid Build Coastguard Worker NeonTensorHandle(const TensorInfo& tensorInfo, 39*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout, 40*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc)) 41*89c4ff92SAndroid Build Coastguard Worker : m_ImportFlags(importFlags), 42*89c4ff92SAndroid Build Coastguard Worker m_Imported(false), 43*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled(false), 44*89c4ff92SAndroid Build Coastguard Worker m_TypeAlignment(GetDataTypeSize(tensorInfo.GetDataType())) 45*89c4ff92SAndroid Build Coastguard Worker 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker { 48*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::BuildArmComputeTensor(m_Tensor, tensorInfo, dataLayout); 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker GetTensor()51*89c4ff92SAndroid Build Coastguard Worker arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const52*89c4ff92SAndroid Build Coastguard Worker arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 53*89c4ff92SAndroid Build Coastguard Worker Allocate()54*89c4ff92SAndroid Build Coastguard Worker virtual void Allocate() override 55*89c4ff92SAndroid Build Coastguard Worker { 56*89c4ff92SAndroid Build Coastguard Worker // If we have enabled Importing, don't Allocate the tensor 57*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled) 58*89c4ff92SAndroid Build Coastguard Worker { 59*89c4ff92SAndroid Build Coastguard Worker armnn::armcomputetensorutils::InitialiseArmComputeTensorEmpty(m_Tensor); 60*89c4ff92SAndroid Build Coastguard Worker } 61*89c4ff92SAndroid Build Coastguard Worker }; 62*89c4ff92SAndroid Build Coastguard Worker Manage()63*89c4ff92SAndroid Build Coastguard Worker virtual void Manage() override 64*89c4ff92SAndroid Build Coastguard Worker { 65*89c4ff92SAndroid Build Coastguard Worker // If we have enabled Importing, don't manage the tensor 66*89c4ff92SAndroid Build Coastguard Worker if (!m_IsImportEnabled) 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(m_MemoryGroup != nullptr); 69*89c4ff92SAndroid Build Coastguard Worker m_MemoryGroup->manage(&m_Tensor); 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker GetParent() const73*89c4ff92SAndroid Build Coastguard Worker virtual ITensorHandle* GetParent() const override { return nullptr; } 74*89c4ff92SAndroid Build Coastguard Worker GetDataType() const75*89c4ff92SAndroid Build Coastguard Worker virtual arm_compute::DataType GetDataType() const override 76*89c4ff92SAndroid Build Coastguard Worker { 77*89c4ff92SAndroid Build Coastguard Worker return m_Tensor.info()->data_type(); 78*89c4ff92SAndroid Build Coastguard Worker } 79*89c4ff92SAndroid Build Coastguard Worker SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> & memoryGroup)80*89c4ff92SAndroid Build Coastguard Worker virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>& memoryGroup) override 81*89c4ff92SAndroid Build Coastguard Worker { 82*89c4ff92SAndroid Build Coastguard Worker m_MemoryGroup = PolymorphicPointerDowncast<arm_compute::MemoryGroup>(memoryGroup); 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker Map(bool) const85*89c4ff92SAndroid Build Coastguard Worker virtual const void* Map(bool /* blocking = true */) const override 86*89c4ff92SAndroid Build Coastguard Worker { 87*89c4ff92SAndroid Build Coastguard Worker return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 88*89c4ff92SAndroid Build Coastguard Worker } 89*89c4ff92SAndroid Build Coastguard Worker Unmap() const90*89c4ff92SAndroid Build Coastguard Worker virtual void Unmap() const override {} 91*89c4ff92SAndroid Build Coastguard Worker GetStrides() const92*89c4ff92SAndroid Build Coastguard Worker TensorShape GetStrides() const override 93*89c4ff92SAndroid Build Coastguard Worker { 94*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker GetShape() const97*89c4ff92SAndroid Build Coastguard Worker TensorShape GetShape() const override 98*89c4ff92SAndroid Build Coastguard Worker { 99*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 100*89c4ff92SAndroid Build Coastguard Worker } 101*89c4ff92SAndroid Build Coastguard Worker SetImportFlags(MemorySourceFlags importFlags)102*89c4ff92SAndroid Build Coastguard Worker void SetImportFlags(MemorySourceFlags importFlags) 103*89c4ff92SAndroid Build Coastguard Worker { 104*89c4ff92SAndroid Build Coastguard Worker m_ImportFlags = importFlags; 105*89c4ff92SAndroid Build Coastguard Worker } 106*89c4ff92SAndroid Build Coastguard Worker GetImportFlags() const107*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags GetImportFlags() const override 108*89c4ff92SAndroid Build Coastguard Worker { 109*89c4ff92SAndroid Build Coastguard Worker return m_ImportFlags; 110*89c4ff92SAndroid Build Coastguard Worker } 111*89c4ff92SAndroid Build Coastguard Worker SetImportEnabledFlag(bool importEnabledFlag)112*89c4ff92SAndroid Build Coastguard Worker void SetImportEnabledFlag(bool importEnabledFlag) 113*89c4ff92SAndroid Build Coastguard Worker { 114*89c4ff92SAndroid Build Coastguard Worker m_IsImportEnabled = importEnabledFlag; 115*89c4ff92SAndroid Build Coastguard Worker } 116*89c4ff92SAndroid Build Coastguard Worker CanBeImported(void * memory,MemorySource source)117*89c4ff92SAndroid Build Coastguard Worker bool CanBeImported(void* memory, MemorySource source) override 118*89c4ff92SAndroid Build Coastguard Worker { 119*89c4ff92SAndroid Build Coastguard Worker if (source != MemorySource::Malloc || reinterpret_cast<uintptr_t>(memory) % m_TypeAlignment) 120*89c4ff92SAndroid Build Coastguard Worker { 121*89c4ff92SAndroid Build Coastguard Worker return false; 122*89c4ff92SAndroid Build Coastguard Worker } 123*89c4ff92SAndroid Build Coastguard Worker return true; 124*89c4ff92SAndroid Build Coastguard Worker } 125*89c4ff92SAndroid Build Coastguard Worker Import(void * memory,MemorySource source)126*89c4ff92SAndroid Build Coastguard Worker virtual bool Import(void* memory, MemorySource source) override 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker if (m_ImportFlags & static_cast<MemorySourceFlags>(source)) 129*89c4ff92SAndroid Build Coastguard Worker { 130*89c4ff92SAndroid Build Coastguard Worker if (source == MemorySource::Malloc && m_IsImportEnabled) 131*89c4ff92SAndroid Build Coastguard Worker { 132*89c4ff92SAndroid Build Coastguard Worker if (!CanBeImported(memory, source)) 133*89c4ff92SAndroid Build Coastguard Worker { 134*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("NeonTensorHandle::Import Attempting to import unaligned memory"); 135*89c4ff92SAndroid Build Coastguard Worker } 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker // m_Tensor not yet Allocated 138*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && !m_Tensor.buffer()) 139*89c4ff92SAndroid Build Coastguard Worker { 140*89c4ff92SAndroid Build Coastguard Worker arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 141*89c4ff92SAndroid Build Coastguard Worker // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 142*89c4ff92SAndroid Build Coastguard Worker // with the Status error message 143*89c4ff92SAndroid Build Coastguard Worker m_Imported = bool(status); 144*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported) 145*89c4ff92SAndroid Build Coastguard Worker { 146*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(status.error_description()); 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker return m_Imported; 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker // m_Tensor.buffer() initially allocated with Allocate(). 152*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported && m_Tensor.buffer()) 153*89c4ff92SAndroid Build Coastguard Worker { 154*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException( 155*89c4ff92SAndroid Build Coastguard Worker "NeonTensorHandle::Import Attempting to import on an already allocated tensor"); 156*89c4ff92SAndroid Build Coastguard Worker } 157*89c4ff92SAndroid Build Coastguard Worker 158*89c4ff92SAndroid Build Coastguard Worker // m_Tensor.buffer() previously imported. 159*89c4ff92SAndroid Build Coastguard Worker if (m_Imported) 160*89c4ff92SAndroid Build Coastguard Worker { 161*89c4ff92SAndroid Build Coastguard Worker arm_compute::Status status = m_Tensor.allocator()->import_memory(memory); 162*89c4ff92SAndroid Build Coastguard Worker // Use the overloaded bool operator of Status to check if it worked, if not throw an exception 163*89c4ff92SAndroid Build Coastguard Worker // with the Status error message 164*89c4ff92SAndroid Build Coastguard Worker m_Imported = bool(status); 165*89c4ff92SAndroid Build Coastguard Worker if (!m_Imported) 166*89c4ff92SAndroid Build Coastguard Worker { 167*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(status.error_description()); 168*89c4ff92SAndroid Build Coastguard Worker } 169*89c4ff92SAndroid Build Coastguard Worker return m_Imported; 170*89c4ff92SAndroid Build Coastguard Worker } 171*89c4ff92SAndroid Build Coastguard Worker } 172*89c4ff92SAndroid Build Coastguard Worker else 173*89c4ff92SAndroid Build Coastguard Worker { 174*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("NeonTensorHandle::Import is disabled"); 175*89c4ff92SAndroid Build Coastguard Worker } 176*89c4ff92SAndroid Build Coastguard Worker } 177*89c4ff92SAndroid Build Coastguard Worker else 178*89c4ff92SAndroid Build Coastguard Worker { 179*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("NeonTensorHandle::Incorrect import flag"); 180*89c4ff92SAndroid Build Coastguard Worker } 181*89c4ff92SAndroid Build Coastguard Worker return false; 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker 184*89c4ff92SAndroid Build Coastguard Worker private: 185*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyOutTo(void * memory) const186*89c4ff92SAndroid Build Coastguard Worker void CopyOutTo(void* memory) const override 187*89c4ff92SAndroid Build Coastguard Worker { 188*89c4ff92SAndroid Build Coastguard Worker switch (this->GetDataType()) 189*89c4ff92SAndroid Build Coastguard Worker { 190*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 191*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 192*89c4ff92SAndroid Build Coastguard Worker static_cast<float*>(memory)); 193*89c4ff92SAndroid Build Coastguard Worker break; 194*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 195*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 196*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 197*89c4ff92SAndroid Build Coastguard Worker static_cast<uint8_t*>(memory)); 198*89c4ff92SAndroid Build Coastguard Worker break; 199*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 200*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 201*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 202*89c4ff92SAndroid Build Coastguard Worker static_cast<int8_t*>(memory)); 203*89c4ff92SAndroid Build Coastguard Worker break; 204*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::BFLOAT16: 205*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 206*89c4ff92SAndroid Build Coastguard Worker static_cast<armnn::BFloat16*>(memory)); 207*89c4ff92SAndroid Build Coastguard Worker break; 208*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 209*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 210*89c4ff92SAndroid Build Coastguard Worker static_cast<armnn::Half*>(memory)); 211*89c4ff92SAndroid Build Coastguard Worker break; 212*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 213*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 214*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 215*89c4ff92SAndroid Build Coastguard Worker static_cast<int16_t*>(memory)); 216*89c4ff92SAndroid Build Coastguard Worker break; 217*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 218*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 219*89c4ff92SAndroid Build Coastguard Worker static_cast<int32_t*>(memory)); 220*89c4ff92SAndroid Build Coastguard Worker break; 221*89c4ff92SAndroid Build Coastguard Worker default: 222*89c4ff92SAndroid Build Coastguard Worker { 223*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker } 226*89c4ff92SAndroid Build Coastguard Worker } 227*89c4ff92SAndroid Build Coastguard Worker 228*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyInFrom(const void * memory)229*89c4ff92SAndroid Build Coastguard Worker void CopyInFrom(const void* memory) override 230*89c4ff92SAndroid Build Coastguard Worker { 231*89c4ff92SAndroid Build Coastguard Worker switch (this->GetDataType()) 232*89c4ff92SAndroid Build Coastguard Worker { 233*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 234*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 235*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 236*89c4ff92SAndroid Build Coastguard Worker break; 237*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 238*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 239*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 240*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 241*89c4ff92SAndroid Build Coastguard Worker break; 242*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 243*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 244*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8_PER_CHANNEL: 245*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 246*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 247*89c4ff92SAndroid Build Coastguard Worker break; 248*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::BFLOAT16: 249*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::BFloat16*>(memory), 250*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 251*89c4ff92SAndroid Build Coastguard Worker break; 252*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F16: 253*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const armnn::Half*>(memory), 254*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 255*89c4ff92SAndroid Build Coastguard Worker break; 256*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 257*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 258*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 259*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 260*89c4ff92SAndroid Build Coastguard Worker break; 261*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 262*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 263*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 264*89c4ff92SAndroid Build Coastguard Worker break; 265*89c4ff92SAndroid Build Coastguard Worker default: 266*89c4ff92SAndroid Build Coastguard Worker { 267*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 268*89c4ff92SAndroid Build Coastguard Worker } 269*89c4ff92SAndroid Build Coastguard Worker } 270*89c4ff92SAndroid Build Coastguard Worker } 271*89c4ff92SAndroid Build Coastguard Worker 272*89c4ff92SAndroid Build Coastguard Worker arm_compute::Tensor m_Tensor; 273*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<arm_compute::MemoryGroup> m_MemoryGroup; 274*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags m_ImportFlags; 275*89c4ff92SAndroid Build Coastguard Worker bool m_Imported; 276*89c4ff92SAndroid Build Coastguard Worker bool m_IsImportEnabled; 277*89c4ff92SAndroid Build Coastguard Worker const uintptr_t m_TypeAlignment; 278*89c4ff92SAndroid Build Coastguard Worker }; 279*89c4ff92SAndroid Build Coastguard Worker 280*89c4ff92SAndroid Build Coastguard Worker class NeonSubTensorHandle : public IAclTensorHandle 281*89c4ff92SAndroid Build Coastguard Worker { 282*89c4ff92SAndroid Build Coastguard Worker public: NeonSubTensorHandle(IAclTensorHandle * parent,const arm_compute::TensorShape & shape,const arm_compute::Coordinates & coords)283*89c4ff92SAndroid Build Coastguard Worker NeonSubTensorHandle(IAclTensorHandle* parent, 284*89c4ff92SAndroid Build Coastguard Worker const arm_compute::TensorShape& shape, 285*89c4ff92SAndroid Build Coastguard Worker const arm_compute::Coordinates& coords) 286*89c4ff92SAndroid Build Coastguard Worker : m_Tensor(&parent->GetTensor(), shape, coords) 287*89c4ff92SAndroid Build Coastguard Worker { 288*89c4ff92SAndroid Build Coastguard Worker parentHandle = parent; 289*89c4ff92SAndroid Build Coastguard Worker } 290*89c4ff92SAndroid Build Coastguard Worker GetTensor()291*89c4ff92SAndroid Build Coastguard Worker arm_compute::ITensor& GetTensor() override { return m_Tensor; } GetTensor() const292*89c4ff92SAndroid Build Coastguard Worker arm_compute::ITensor const& GetTensor() const override { return m_Tensor; } 293*89c4ff92SAndroid Build Coastguard Worker Allocate()294*89c4ff92SAndroid Build Coastguard Worker virtual void Allocate() override {} Manage()295*89c4ff92SAndroid Build Coastguard Worker virtual void Manage() override {} 296*89c4ff92SAndroid Build Coastguard Worker GetParent() const297*89c4ff92SAndroid Build Coastguard Worker virtual ITensorHandle* GetParent() const override { return parentHandle; } 298*89c4ff92SAndroid Build Coastguard Worker GetDataType() const299*89c4ff92SAndroid Build Coastguard Worker virtual arm_compute::DataType GetDataType() const override 300*89c4ff92SAndroid Build Coastguard Worker { 301*89c4ff92SAndroid Build Coastguard Worker return m_Tensor.info()->data_type(); 302*89c4ff92SAndroid Build Coastguard Worker } 303*89c4ff92SAndroid Build Coastguard Worker SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup> &)304*89c4ff92SAndroid Build Coastguard Worker virtual void SetMemoryGroup(const std::shared_ptr<arm_compute::IMemoryGroup>&) override {} 305*89c4ff92SAndroid Build Coastguard Worker Map(bool) const306*89c4ff92SAndroid Build Coastguard Worker virtual const void* Map(bool /* blocking = true */) const override 307*89c4ff92SAndroid Build Coastguard Worker { 308*89c4ff92SAndroid Build Coastguard Worker return static_cast<const void*>(m_Tensor.buffer() + m_Tensor.info()->offset_first_element_in_bytes()); 309*89c4ff92SAndroid Build Coastguard Worker } Unmap() const310*89c4ff92SAndroid Build Coastguard Worker virtual void Unmap() const override {} 311*89c4ff92SAndroid Build Coastguard Worker GetStrides() const312*89c4ff92SAndroid Build Coastguard Worker TensorShape GetStrides() const override 313*89c4ff92SAndroid Build Coastguard Worker { 314*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetStrides(m_Tensor.info()->strides_in_bytes()); 315*89c4ff92SAndroid Build Coastguard Worker } 316*89c4ff92SAndroid Build Coastguard Worker GetShape() const317*89c4ff92SAndroid Build Coastguard Worker TensorShape GetShape() const override 318*89c4ff92SAndroid Build Coastguard Worker { 319*89c4ff92SAndroid Build Coastguard Worker return armcomputetensorutils::GetShape(m_Tensor.info()->tensor_shape()); 320*89c4ff92SAndroid Build Coastguard Worker } 321*89c4ff92SAndroid Build Coastguard Worker 322*89c4ff92SAndroid Build Coastguard Worker private: 323*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyOutTo(void * memory) const324*89c4ff92SAndroid Build Coastguard Worker void CopyOutTo(void* memory) const override 325*89c4ff92SAndroid Build Coastguard Worker { 326*89c4ff92SAndroid Build Coastguard Worker switch (this->GetDataType()) 327*89c4ff92SAndroid Build Coastguard Worker { 328*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 329*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 330*89c4ff92SAndroid Build Coastguard Worker static_cast<float*>(memory)); 331*89c4ff92SAndroid Build Coastguard Worker break; 332*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 333*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 334*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 335*89c4ff92SAndroid Build Coastguard Worker static_cast<uint8_t*>(memory)); 336*89c4ff92SAndroid Build Coastguard Worker break; 337*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 338*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 339*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 340*89c4ff92SAndroid Build Coastguard Worker static_cast<int8_t*>(memory)); 341*89c4ff92SAndroid Build Coastguard Worker break; 342*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 343*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 344*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 345*89c4ff92SAndroid Build Coastguard Worker static_cast<int16_t*>(memory)); 346*89c4ff92SAndroid Build Coastguard Worker break; 347*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 348*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(this->GetTensor(), 349*89c4ff92SAndroid Build Coastguard Worker static_cast<int32_t*>(memory)); 350*89c4ff92SAndroid Build Coastguard Worker break; 351*89c4ff92SAndroid Build Coastguard Worker default: 352*89c4ff92SAndroid Build Coastguard Worker { 353*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 354*89c4ff92SAndroid Build Coastguard Worker } 355*89c4ff92SAndroid Build Coastguard Worker } 356*89c4ff92SAndroid Build Coastguard Worker } 357*89c4ff92SAndroid Build Coastguard Worker 358*89c4ff92SAndroid Build Coastguard Worker // Only used for testing CopyInFrom(const void * memory)359*89c4ff92SAndroid Build Coastguard Worker void CopyInFrom(const void* memory) override 360*89c4ff92SAndroid Build Coastguard Worker { 361*89c4ff92SAndroid Build Coastguard Worker switch (this->GetDataType()) 362*89c4ff92SAndroid Build Coastguard Worker { 363*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::F32: 364*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const float*>(memory), 365*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 366*89c4ff92SAndroid Build Coastguard Worker break; 367*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::U8: 368*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8: 369*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const uint8_t*>(memory), 370*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 371*89c4ff92SAndroid Build Coastguard Worker break; 372*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM8: 373*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QASYMM8_SIGNED: 374*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int8_t*>(memory), 375*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 376*89c4ff92SAndroid Build Coastguard Worker break; 377*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S16: 378*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::QSYMM16: 379*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int16_t*>(memory), 380*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 381*89c4ff92SAndroid Build Coastguard Worker break; 382*89c4ff92SAndroid Build Coastguard Worker case arm_compute::DataType::S32: 383*89c4ff92SAndroid Build Coastguard Worker armcomputetensorutils::CopyArmComputeITensorData(static_cast<const int32_t*>(memory), 384*89c4ff92SAndroid Build Coastguard Worker this->GetTensor()); 385*89c4ff92SAndroid Build Coastguard Worker break; 386*89c4ff92SAndroid Build Coastguard Worker default: 387*89c4ff92SAndroid Build Coastguard Worker { 388*89c4ff92SAndroid Build Coastguard Worker throw armnn::UnimplementedException(); 389*89c4ff92SAndroid Build Coastguard Worker } 390*89c4ff92SAndroid Build Coastguard Worker } 391*89c4ff92SAndroid Build Coastguard Worker } 392*89c4ff92SAndroid Build Coastguard Worker 393*89c4ff92SAndroid Build Coastguard Worker arm_compute::SubTensor m_Tensor; 394*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* parentHandle = nullptr; 395*89c4ff92SAndroid Build Coastguard Worker }; 396*89c4ff92SAndroid Build Coastguard Worker 397*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 398