1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker namespace armnn
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker
GetUnpaddedTensorStrides(const TensorInfo & tensorInfo)15*89c4ff92SAndroid Build Coastguard Worker TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker TensorShape shape(tensorInfo.GetShape());
18*89c4ff92SAndroid Build Coastguard Worker auto size = GetDataTypeSize(tensorInfo.GetDataType());
19*89c4ff92SAndroid Build Coastguard Worker auto runningSize = size;
20*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> strides(shape.GetNumDimensions());
21*89c4ff92SAndroid Build Coastguard Worker auto lastIdx = shape.GetNumDimensions()-1;
22*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i=0; i < lastIdx ; i++)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker strides[lastIdx-i] = runningSize;
25*89c4ff92SAndroid Build Coastguard Worker runningSize *= shape[lastIdx-i];
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker strides[0] = runningSize;
28*89c4ff92SAndroid Build Coastguard Worker return TensorShape(shape.GetNumDimensions(), strides.data());
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker
ConstTensorHandle(const TensorInfo & tensorInfo)31*89c4ff92SAndroid Build Coastguard Worker ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo)
32*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
33*89c4ff92SAndroid Build Coastguard Worker , m_Memory(nullptr)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker template <>
GetConstTensor() const38*89c4ff92SAndroid Build Coastguard Worker const void* ConstTensorHandle::GetConstTensor<void>() const
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker return m_Memory;
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
TensorHandle(const TensorInfo & tensorInfo)43*89c4ff92SAndroid Build Coastguard Worker TensorHandle::TensorHandle(const TensorInfo& tensorInfo)
44*89c4ff92SAndroid Build Coastguard Worker : ConstTensorHandle(tensorInfo)
45*89c4ff92SAndroid Build Coastguard Worker , m_MutableMemory(nullptr)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker template <>
GetTensor() const50*89c4ff92SAndroid Build Coastguard Worker void* TensorHandle::GetTensor<void>() const
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker return m_MutableMemory;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker
ScopedTensorHandle(const TensorInfo & tensorInfo)55*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo)
56*89c4ff92SAndroid Build Coastguard Worker : TensorHandle(tensorInfo)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker
ScopedTensorHandle(const ConstTensor & tensor)60*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor)
61*89c4ff92SAndroid Build Coastguard Worker : ScopedTensorHandle(tensor.GetInfo())
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker
ScopedTensorHandle(const ConstTensorHandle & tensorHandle)66*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle)
67*89c4ff92SAndroid Build Coastguard Worker : ScopedTensorHandle(tensorHandle.GetTensorInfo())
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
/// Copy constructor: takes the same TensorInfo as @p other and deep-copies its
/// buffer (if @p other has one) into newly allocated storage.
ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other)
    : TensorHandle(other.GetTensorInfo())
{
    const ScopedTensorHandle& source = other;
    CopyFrom(source);
}
77*89c4ff92SAndroid Build Coastguard Worker
/// Copy assignment: replaces this handle's data with a deep copy of @p other's data.
///
/// Any buffer currently held by this handle is released first. Only the data is copied;
/// this handle keeps its own TensorInfo, so both handles must describe the same number
/// of bytes. If @p other has no buffer, this handle ends up unallocated.
///
/// @throws InvalidArgumentException if the two handles' byte sizes differ. This check
///         happens before anything is freed, so this handle is left untouched.
ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other)
{
    // Self-assignment would free our own buffer and then try to copy from it,
    // silently discarding the data.
    if (this == &other)
    {
        return *this;
    }

    // The TensorInfo is not reassigned here. A size mismatch would make CopyFrom write
    // other's byte count into a buffer sized for ours; in release builds that is only
    // guarded by ARMNN_ASSERT.
    if (other.GetTensorInfo().GetNumBytes() != GetTensorInfo().GetNumBytes())
    {
        throw InvalidArgumentException("ScopedTensorHandle::operator= source and destination tensors "
                                       "have different sizes.");
    }

    ::operator delete(GetTensor<void>());
    SetMemory(nullptr);
    CopyFrom(other);
    return *this;
}
85*89c4ff92SAndroid Build Coastguard Worker
~ScopedTensorHandle()86*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::~ScopedTensorHandle()
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker ::operator delete(GetTensor<void>());
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
Allocate()91*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::Allocate()
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker if (GetTensor<void>() == nullptr)
94*89c4ff92SAndroid Build Coastguard Worker {
95*89c4ff92SAndroid Build Coastguard Worker SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker else
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle"
100*89c4ff92SAndroid Build Coastguard Worker "that already has allocated memory.");
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker
CopyOutTo(void * memory) const104*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyOutTo(void* memory) const
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker
CopyInFrom(const void * memory)109*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyInFrom(const void* memory)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker
CopyFrom(const ScopedTensorHandle & other)114*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker
CopyFrom(const void * srcMemory,unsigned int numBytes)119*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(GetTensor<void>() == nullptr);
122*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker if (srcMemory)
125*89c4ff92SAndroid Build Coastguard Worker {
126*89c4ff92SAndroid Build Coastguard Worker Allocate();
127*89c4ff92SAndroid Build Coastguard Worker memcpy(GetTensor<void>(), srcMemory, numBytes);
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker
Allocate()131*89c4ff92SAndroid Build Coastguard Worker void PassthroughTensorHandle::Allocate()
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called");
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker
Allocate()136*89c4ff92SAndroid Build Coastguard Worker void ConstPassthroughTensorHandle::Allocate()
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called");
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
142