xref: /aosp_15_r20/external/armnn/src/backends/neon/workloads/NeonFullyConnectedWorkload.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "NeonBaseWorkload.hpp"
9 
10 #include <arm_compute/core/Error.h>
11 #include <arm_compute/runtime/IFunction.h>
12 #include <arm_compute/runtime/MemoryManagerOnDemand.h>
13 #include <arm_compute/runtime/Tensor.h>
14 
15 #include <memory>
16 
17 namespace armnn
18 {
19 
20 arm_compute::Status NeonFullyConnectedWorkloadValidate(const TensorInfo& input,
21                                                        const TensorInfo& output,
22                                                        const TensorInfo& weights,
23                                                        const Optional<TensorInfo>& biases,
24                                                        const FullyConnectedDescriptor& descriptor,
25                                                        const ActivationDescriptor* activationDescriptor = nullptr);
26 
27 class NeonFullyConnectedWorkload : public NeonBaseWorkload<FullyConnectedQueueDescriptor>
28 {
29 public:
30     NeonFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info,
31                                std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
32 
33     virtual void Execute() const override;
34 
35 private:
36     std::unique_ptr<arm_compute::IFunction> m_FullyConnectedLayer;
37     mutable std::unique_ptr<arm_compute::Tensor> m_WeightsTensor;
38     mutable std::unique_ptr<arm_compute::Tensor> m_BiasesTensor;
39     TensorInfo m_WeightsTensorInfo;
40     TensorInfo m_BiasesTensorInfo;
41     mutable bool prepared = false;
42 };
43 
44 } //namespace armnn
45 
46