xref: /aosp_15_r20/external/armnn/src/dynamic/sample/SampleDynamicAdditionWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/ITensorHandle.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "SampleDynamicAdditionWorkload.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "SampleTensorHandle.hpp"
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker namespace sdb // sample dynamic backend
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker 
// Returns the TensorInfo stored in the given tensor handle.
//
// The handle is downcast to SampleTensorHandle with an unchecked static_cast,
// so the caller must only pass handles whose dynamic type is SampleTensorHandle
// (presumably every handle created by this sample backend - confirm against the
// backend's tensor handle factory).
inline const armnn::TensorInfo& GetTensorInfo(const armnn::ITensorHandle* tensorHandle)
{
    const auto* const sampleHandle = static_cast<const SampleTensorHandle*>(tensorHandle);
    return sampleHandle->GetTensorInfo();
}
21*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensorData(unsigned int idx,const armnn::AdditionQueueDescriptor & data)22*89c4ff92SAndroid Build Coastguard Worker const float* GetInputTensorData(unsigned int idx, const armnn::AdditionQueueDescriptor& data)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker     const armnn::ITensorHandle* tensorHandle = data.m_Inputs[idx];
25*89c4ff92SAndroid Build Coastguard Worker     return reinterpret_cast<const float*>(tensorHandle->Map());
26*89c4ff92SAndroid Build Coastguard Worker }
27*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensorData(unsigned int idx,const armnn::AdditionQueueDescriptor & data)28*89c4ff92SAndroid Build Coastguard Worker float* GetOutputTensorData(unsigned int idx, const armnn::AdditionQueueDescriptor& data)
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker     armnn::ITensorHandle* tensorHandle = data.m_Outputs[idx];
31*89c4ff92SAndroid Build Coastguard Worker     return reinterpret_cast<float*>(tensorHandle->Map());
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker 
SampleDynamicAdditionWorkload(const armnn::AdditionQueueDescriptor & descriptor,const armnn::WorkloadInfo & info)34*89c4ff92SAndroid Build Coastguard Worker SampleDynamicAdditionWorkload::SampleDynamicAdditionWorkload(const armnn::AdditionQueueDescriptor& descriptor,
35*89c4ff92SAndroid Build Coastguard Worker                                                              const armnn::WorkloadInfo& info)
36*89c4ff92SAndroid Build Coastguard Worker     : BaseWorkload(descriptor, info)
37*89c4ff92SAndroid Build Coastguard Worker {}
38*89c4ff92SAndroid Build Coastguard Worker 
Execute() const39*89c4ff92SAndroid Build Coastguard Worker void SampleDynamicAdditionWorkload::Execute() const
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& info = GetTensorInfo(m_Data.m_Inputs[0]);
42*89c4ff92SAndroid Build Coastguard Worker     unsigned int num = info.GetNumElements();
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     const float* inputData0 = GetInputTensorData(0, m_Data);
45*89c4ff92SAndroid Build Coastguard Worker     const float* inputData1 = GetInputTensorData(1, m_Data);
46*89c4ff92SAndroid Build Coastguard Worker     float* outputData       = GetOutputTensorData(0, m_Data);
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < num; ++i)
49*89c4ff92SAndroid Build Coastguard Worker     {
50*89c4ff92SAndroid Build Coastguard Worker         outputData[i] = inputData0[i] + inputData1[i];
51*89c4ff92SAndroid Build Coastguard Worker     }
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker } // namespace sdb // sample dynamic backend
55