xref: /aosp_15_r20/external/armnn/src/backends/cl/workloads/ClConstantWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClConstantWorkload.hpp"
7 
8 #include <Half.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <cl/ClTensorHandle.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 
13 #include "ClWorkloadUtils.hpp"
14 
15 namespace armnn
16 {
17 
ClConstantWorkloadValidate(const TensorInfo & output)18 arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output)
19 {
20     const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
21 
22     std::array<arm_compute::DataType,8> supportedTypes = {
23             arm_compute::DataType::F16,
24             arm_compute::DataType::F32,
25             arm_compute::DataType::QASYMM8,
26             arm_compute::DataType::QASYMM8_SIGNED,
27             arm_compute::DataType::QSYMM16,
28             arm_compute::DataType::QSYMM8,
29             arm_compute::DataType::QSYMM8_PER_CHANNEL,
30             arm_compute::DataType::S32
31     };
32     auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type());
33 
34     if (it != end(supportedTypes))
35     {
36         return arm_compute::Status{};
37     }
38     else
39     {
40         return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported DataType"};
41     }
42 }
43 
ClConstantWorkload(const ConstantQueueDescriptor & descriptor,const WorkloadInfo & info,const arm_compute::CLCompileContext &)44 ClConstantWorkload::ClConstantWorkload(const ConstantQueueDescriptor& descriptor,
45                                        const WorkloadInfo& info,
46                                        const arm_compute::CLCompileContext&)
47     : ClBaseWorkload<ConstantQueueDescriptor>(descriptor, info)
48     , m_RanOnce(false)
49 {
50 }
51 
Execute() const52 void ClConstantWorkload::Execute() const
53 {
54     ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClConstantWorkload_Execute", this->GetGuid());
55 
56     // The intermediate tensor held by the corresponding layer output handler can be initialised with the given data
57     // on the first inference, then reused for subsequent inferences.
58     // The initialisation cannot happen at workload construction time since the ACL kernel for the next layer may not
59     // have been configured at the time.
60     if (!m_RanOnce)
61     {
62         const ConstantQueueDescriptor& data = this->m_Data;
63 
64         ARMNN_ASSERT(data.m_LayerOutput != nullptr);
65         arm_compute::CLTensor& output = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetTensor();
66         arm_compute::DataType computeDataType = static_cast<ClTensorHandle*>(data.m_Outputs[0])->GetDataType();
67 
68         switch (computeDataType)
69         {
70             case arm_compute::DataType::F16:
71             {
72                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<Half>());
73                 break;
74             }
75             case arm_compute::DataType::F32:
76             {
77                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<float>());
78                 break;
79             }
80             case arm_compute::DataType::QASYMM8:
81             {
82                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<uint8_t>());
83                 break;
84             }
85             case arm_compute::DataType::QASYMM8_SIGNED:
86             {
87                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
88                 break;
89             }
90             case arm_compute::DataType::QSYMM16:
91             {
92                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int16_t>());
93                 break;
94             }
95             case arm_compute::DataType::QSYMM8:
96             case arm_compute::DataType::QSYMM8_PER_CHANNEL:
97             {
98                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>());
99                 break;
100             }
101             case arm_compute::DataType::S32:
102             {
103                 CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>());
104                 break;
105             }
106             default:
107             {
108                 ARMNN_ASSERT_MSG(false, "Unknown data type");
109                 break;
110             }
111         }
112 
113         m_RanOnce = true;
114     }
115 }
116 
117 } //namespace armnn
118