xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/TensorHandle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker namespace armnn
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker 
GetUnpaddedTensorStrides(const TensorInfo & tensorInfo)15*89c4ff92SAndroid Build Coastguard Worker TensorShape GetUnpaddedTensorStrides(const TensorInfo& tensorInfo)
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker     TensorShape shape(tensorInfo.GetShape());
18*89c4ff92SAndroid Build Coastguard Worker     auto size = GetDataTypeSize(tensorInfo.GetDataType());
19*89c4ff92SAndroid Build Coastguard Worker     auto runningSize = size;
20*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> strides(shape.GetNumDimensions());
21*89c4ff92SAndroid Build Coastguard Worker     auto lastIdx = shape.GetNumDimensions()-1;
22*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i=0; i < lastIdx ; i++)
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         strides[lastIdx-i] = runningSize;
25*89c4ff92SAndroid Build Coastguard Worker         runningSize *= shape[lastIdx-i];
26*89c4ff92SAndroid Build Coastguard Worker     }
27*89c4ff92SAndroid Build Coastguard Worker     strides[0] = runningSize;
28*89c4ff92SAndroid Build Coastguard Worker     return TensorShape(shape.GetNumDimensions(), strides.data());
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker 
ConstTensorHandle(const TensorInfo & tensorInfo)31*89c4ff92SAndroid Build Coastguard Worker ConstTensorHandle::ConstTensorHandle(const TensorInfo& tensorInfo)
32*89c4ff92SAndroid Build Coastguard Worker : m_TensorInfo(tensorInfo)
33*89c4ff92SAndroid Build Coastguard Worker , m_Memory(nullptr)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker template <>
GetConstTensor() const38*89c4ff92SAndroid Build Coastguard Worker const void* ConstTensorHandle::GetConstTensor<void>() const
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker     return m_Memory;
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker 
TensorHandle(const TensorInfo & tensorInfo)43*89c4ff92SAndroid Build Coastguard Worker TensorHandle::TensorHandle(const TensorInfo& tensorInfo)
44*89c4ff92SAndroid Build Coastguard Worker : ConstTensorHandle(tensorInfo)
45*89c4ff92SAndroid Build Coastguard Worker , m_MutableMemory(nullptr)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker template <>
GetTensor() const50*89c4ff92SAndroid Build Coastguard Worker void* TensorHandle::GetTensor<void>() const
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     return m_MutableMemory;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker 
ScopedTensorHandle(const TensorInfo & tensorInfo)55*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const TensorInfo& tensorInfo)
56*89c4ff92SAndroid Build Coastguard Worker : TensorHandle(tensorInfo)
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker 
ScopedTensorHandle(const ConstTensor & tensor)60*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const ConstTensor& tensor)
61*89c4ff92SAndroid Build Coastguard Worker : ScopedTensorHandle(tensor.GetInfo())
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     CopyFrom(tensor.GetMemoryArea(), tensor.GetNumBytes());
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker 
ScopedTensorHandle(const ConstTensorHandle & tensorHandle)66*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::ScopedTensorHandle(const ConstTensorHandle& tensorHandle)
67*89c4ff92SAndroid Build Coastguard Worker : ScopedTensorHandle(tensorHandle.GetTensorInfo())
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker     CopyFrom(tensorHandle.GetConstTensor<void>(), tensorHandle.GetTensorInfo().GetNumBytes());
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker 
/// Copy constructor: performs a deep copy.
/// A fresh buffer is allocated and filled from @p other; nothing is shared with it.
ScopedTensorHandle::ScopedTensorHandle(const ScopedTensorHandle& other)
    : TensorHandle(other.GetTensorInfo())
{
    this->CopyFrom(other);
}
77*89c4ff92SAndroid Build Coastguard Worker 
/// Copy assignment: releases this handle's buffer, then deep-copies @p other's data.
///
/// NOTE(review): this handle's TensorInfo is not replaced. CopyFrom expects @p other's byte size
/// to equal this handle's, so only same-sized handles should be assigned.
ScopedTensorHandle& ScopedTensorHandle::operator=(const ScopedTensorHandle& other)
{
    // On self-assignment the buffer would be freed and nulled before being read back as the source.
    // CopyFrom would then see a null source and the tensor's contents would be silently lost.
    if (this == &other)
    {
        return *this;
    }

    ::operator delete(GetTensor<void>());
    SetMemory(nullptr);
    CopyFrom(other);
    return *this;
}
85*89c4ff92SAndroid Build Coastguard Worker 
~ScopedTensorHandle()86*89c4ff92SAndroid Build Coastguard Worker ScopedTensorHandle::~ScopedTensorHandle()
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker     ::operator delete(GetTensor<void>());
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker 
Allocate()91*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::Allocate()
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker     if (GetTensor<void>() == nullptr)
94*89c4ff92SAndroid Build Coastguard Worker     {
95*89c4ff92SAndroid Build Coastguard Worker         SetMemory(::operator new(GetTensorInfo().GetNumBytes()));
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker     else
98*89c4ff92SAndroid Build Coastguard Worker     {
99*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("TensorHandle::Allocate Trying to allocate a TensorHandle"
100*89c4ff92SAndroid Build Coastguard Worker             "that already has allocated memory.");
101*89c4ff92SAndroid Build Coastguard Worker     }
102*89c4ff92SAndroid Build Coastguard Worker }
103*89c4ff92SAndroid Build Coastguard Worker 
CopyOutTo(void * memory) const104*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyOutTo(void* memory) const
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker     memcpy(memory, GetTensor<void>(), GetTensorInfo().GetNumBytes());
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker 
CopyInFrom(const void * memory)109*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyInFrom(const void* memory)
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker     memcpy(GetTensor<void>(), memory, GetTensorInfo().GetNumBytes());
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker 
CopyFrom(const ScopedTensorHandle & other)114*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyFrom(const ScopedTensorHandle& other)
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker     CopyFrom(other.GetTensor<void>(), other.GetTensorInfo().GetNumBytes());
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker 
CopyFrom(const void * srcMemory,unsigned int numBytes)119*89c4ff92SAndroid Build Coastguard Worker void ScopedTensorHandle::CopyFrom(const void* srcMemory, unsigned int numBytes)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(GetTensor<void>() == nullptr);
122*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(GetTensorInfo().GetNumBytes() == numBytes);
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     if (srcMemory)
125*89c4ff92SAndroid Build Coastguard Worker     {
126*89c4ff92SAndroid Build Coastguard Worker         Allocate();
127*89c4ff92SAndroid Build Coastguard Worker         memcpy(GetTensor<void>(), srcMemory, numBytes);
128*89c4ff92SAndroid Build Coastguard Worker     }
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker 
Allocate()131*89c4ff92SAndroid Build Coastguard Worker void PassthroughTensorHandle::Allocate()
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException("PassthroughTensorHandle::Allocate() should never be called");
134*89c4ff92SAndroid Build Coastguard Worker }
135*89c4ff92SAndroid Build Coastguard Worker 
Allocate()136*89c4ff92SAndroid Build Coastguard Worker void ConstPassthroughTensorHandle::Allocate()
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException("ConstPassthroughTensorHandle::Allocate() should never be called");
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
142