1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClFullyConnectedWorkload.hpp"
7 #include <cl/ClTensorHandle.hpp>
8 #include <armnn/backends/TensorHandle.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 #include <aclCommon/ArmComputeUtils.hpp>
11 #include <cl/ClLayerSupport.hpp>
12
13 #include "ClWorkloadUtils.hpp"
14
15 namespace armnn
16 {
17 using namespace armcomputetensorutils;
18
ClFullyConnectedWorkloadValidate(const TensorInfo & input,const TensorInfo & output,const TensorInfo & weights,const Optional<TensorInfo> & biases,const FullyConnectedDescriptor & descriptor,const ActivationDescriptor * activationDescriptor)19 arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
20 const TensorInfo& output,
21 const TensorInfo& weights,
22 const Optional<TensorInfo>& biases,
23 const FullyConnectedDescriptor& descriptor,
24 const ActivationDescriptor* activationDescriptor)
25 {
26 const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
27 const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
28 arm_compute::TensorInfo aclWeights = BuildArmComputeTensorInfo(weights);
29 aclWeights.set_are_values_constant(weights.IsConstant());
30
31 arm_compute::TensorInfo aclBiases;
32 arm_compute::TensorInfo* optionalAclBiases = nullptr;
33 if (descriptor.m_BiasEnabled)
34 {
35 ARMNN_ASSERT(biases.has_value());
36 // Same for bias as weights. We don't currently support non const.
37 if (!biases.value().IsConstant())
38 {
39 return arm_compute::Status{arm_compute::ErrorCode::RUNTIME_ERROR,
40 "Arm NN ClFullyConnectedWorkload does not support non constant bias."};
41 }
42 aclBiases = BuildArmComputeTensorInfo(biases.value());
43 aclBiases.set_are_values_constant(biases.value().IsConstant());
44 optionalAclBiases = &aclBiases;
45 }
46
47 const arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo =
48 ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor, activationDescriptor);
49 return arm_compute::CLFullyConnectedLayer::validate(&aclInput,
50 &aclWeights,
51 optionalAclBiases,
52 &aclOutput,
53 fullyConnectedLayerInfo);
54 }
55
ClFullyConnectedWorkload(const FullyConnectedQueueDescriptor & descriptor,const WorkloadInfo & info,std::shared_ptr<arm_compute::MemoryManagerOnDemand> & memoryManager,const arm_compute::CLCompileContext & clCompileContext)56 ClFullyConnectedWorkload::ClFullyConnectedWorkload(
57 const FullyConnectedQueueDescriptor& descriptor,
58 const WorkloadInfo& info,
59 std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager,
60 const arm_compute::CLCompileContext& clCompileContext)
61 : ClBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info), m_FullyConnectedLayer(memoryManager)
62 {
63 // Add details for profiling output
64 WorkloadInfo detailsInfo;
65
66 detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
67 detailsInfo.m_OutputTensorInfos = info.m_OutputTensorInfos;
68 detailsInfo.m_WeightsTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[1]);
69 if (descriptor.m_Parameters.m_BiasEnabled)
70 {
71 detailsInfo.m_BiasTensorInfo = armnn::Optional<armnn::TensorInfo>(info.m_InputTensorInfos[2]);
72 }
73
74 // Report Profiling Details
75 ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClFullyConnectedWorkload_Construct",
76 descriptor.m_Parameters,
77 detailsInfo,
78 this->GetGuid());
79
80 m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", descriptor.m_Parameters.GetNumInputs(),
81 1);
82
83 arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
84 arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
85 arm_compute::ICLTensor& weights = PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
86
87 arm_compute::ICLTensor* bias = nullptr;
88 if (m_Data.m_Parameters.m_BiasEnabled)
89 {
90 bias = &PolymorphicDowncast<IClTensorHandle*>(m_Data.m_Inputs[2])->GetTensor();
91 }
92
93 const arm_compute::ActivationLayerInfo activationInfo = ConvertAdditionalInfoToAclActivationLayerInfo(descriptor);
94
95 arm_compute::FullyConnectedLayerInfo fc_info =
96 ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor.m_Parameters,
97 activationInfo);
98
99 {
100 ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClFullyConnectedWorkload_configure");
101 m_FullyConnectedLayer.configure(clCompileContext,
102 &input,
103 &weights,
104 bias,
105 &output,
106 fc_info);
107 }
108 }
109
Execute() const110 void ClFullyConnectedWorkload::Execute() const
111 {
112 ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClFullyConnectedWorkload_Execute", this->GetGuid());
113 RunClFunction(m_FullyConnectedLayer, CHECK_LOCATION());
114 }
115
116 } //namespace armnn
117