xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/TosaRefWorkloadFactory.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include "TosaRefMemoryManager.hpp"

#include <armnn/Optional.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <memory>
#include <string>
12 
13 
14 namespace armnn
15 {
16 
17 // Reference workload factory.
18 class TosaRefWorkloadFactory : public IWorkloadFactory
19 {
20 public:
21     explicit TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager);
22     TosaRefWorkloadFactory();
23 
~TosaRefWorkloadFactory()24     ~TosaRefWorkloadFactory() {}
25 
26     const BackendId& GetBackendId() const override;
27 
28     static bool IsLayerSupported(const Layer& layer,
29                                  Optional<DataType> dataType,
30                                  std::string& outReasonIfUnsupported);
31 
32     static bool IsLayerSupported(const IConnectableLayer& layer,
33                                  Optional<DataType> dataType,
34                                  std::string& outReasonIfUnsupported,
35                                  const ModelOptions& modelOptions);
36 
SupportsSubTensors() const37     bool SupportsSubTensors() const override { return false; }
38 
39     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const40     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
41                                                          TensorShape const& subTensorShape,
42                                                          unsigned int const* subTensorOrigin) const override
43     {
44         IgnoreUnused(parent, subTensorShape, subTensorOrigin);
45         return nullptr;
46     }
47 
48     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
49     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50                                                       const bool IsMemoryManaged = true) const override;
51 
52     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
53     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
54                                                       DataLayout dataLayout,
55                                                       const bool IsMemoryManaged = true) const override;
56 
57     std::unique_ptr<IWorkload> CreateWorkload(LayerType type,
58                                               const QueueDescriptor& descriptor,
59                                               const WorkloadInfo& info) const override;
60 
61 private:
62     template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
63     std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const;
64 
65     mutable std::shared_ptr<TosaRefMemoryManager> m_MemoryManager;
66 };
67 
68 } // namespace armnn
69