xref: /aosp_15_r20/external/armnn/src/backends/neon/workloads/NeonConstantWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "NeonConstantWorkload.hpp"

#include <arm_compute/core/Types.h>
#include <BFloat16.hpp>
#include <Half.hpp>
#include <aclCommon/ArmComputeTensorUtils.hpp>
#include <armnn/Exceptions.hpp>
#include <armnn/utility/PolymorphicDowncast.hpp>
#include <neon/NeonTensorHandle.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include "NeonBaseWorkload.hpp"

#include <algorithm>
#include <array>
16 
17 namespace armnn
18 {
19 
NeonConstantWorkloadValidate(const TensorInfo & output)20 arm_compute::Status NeonConstantWorkloadValidate(const TensorInfo& output)
21 {
22     const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23 
24     std::array<arm_compute::DataType,9> supportedTypes = {
25             arm_compute::DataType::BFLOAT16,
26             arm_compute::DataType::F16,
27             arm_compute::DataType::F32,
28             arm_compute::DataType::QASYMM8,
29             arm_compute::DataType::QASYMM8_SIGNED,
30             arm_compute::DataType::QSYMM16,
31             arm_compute::DataType::QSYMM8,
32             arm_compute::DataType::QSYMM8_PER_CHANNEL,
33             arm_compute::DataType::S32
34     };
35     auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
36 
37     if (it != end(supportedTypes))
38     {
39         return arm_compute::Status{};
40     }
41     else
42     {
43         return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
44     }
45 }
46 
NeonConstantWorkload(const ConstantQueueDescriptor & descriptor,const WorkloadInfo & info)47 NeonConstantWorkload::NeonConstantWorkload(const ConstantQueueDescriptor& descriptor,
48                                            const WorkloadInfo& info)
49     : NeonBaseWorkload<ConstantQueueDescriptor>(descriptor, info)
50     , m_RanOnce(false)
51 {
52 }
53 
Execute() const54 void NeonConstantWorkload::Execute() const
55 {
56     ARMNN_SCOPED_PROFILING_EVENT_NEON_GUID("NeonConstantWorkload_Execute", this->GetGuid());
57 
58     using namespace armcomputetensorutils;
59 
60     // The intermediate tensor held by the corresponding layer output handler can be initialised with the
61     // given data on the first inference, then reused for subsequent inferences.
62     // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer
63     // may not have been configured at the time.
64     if (!m_RanOnce)
65     {
66         const ConstantQueueDescriptor& data = this->m_Data;
67 
68         ARMNN_ASSERT(data.m_LayerOutput != nullptr);
69         arm_compute::ITensor& output =
70             PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetTensor();
71         arm_compute::DataType computeDataType =
72             PolymorphicDowncast<NeonTensorHandle*>(data.m_Outputs[0])->GetDataType();
73 
74         switch (computeDataType)
75         {
76             case arm_compute::DataType::BFLOAT16:
77             {
78                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<BFloat16>(), output);
79                 break;
80             }
81             case arm_compute::DataType::F16:
82             {
83                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<Half>(), output);
84                 break;
85             }
86             case arm_compute::DataType::F32:
87             {
88                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<float>(), output);
89                 break;
90             }
91             case arm_compute::DataType::QASYMM8:
92             {
93                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<uint8_t>(), output);
94                 break;
95             }
96             case arm_compute::DataType::QASYMM8_SIGNED:
97             {
98                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
99                 break;
100             }
101             case arm_compute::DataType::QSYMM16:
102             {
103                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int16_t>(), output);
104                 break;
105             }
106             case arm_compute::DataType::QSYMM8:
107             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
108             {
109                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int8_t>(), output);
110                 break;
111             }
112             case arm_compute::DataType::S32:
113             {
114                 CopyArmComputeITensorData(data.m_LayerOutput->GetConstTensor<int32_t>(), output);
115                 break;
116             }
117             default:
118             {
119                 ARMNN_ASSERT_MSG(false, "Unknown data type");
120                 break;
121             }
122         }
123 
124         m_RanOnce = true;
125     }
126 }
127 
128 } //namespace armnn
129