xref: /aosp_15_r20/external/armnn/src/backends/neon/NeonTensorHandleFactory.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <set>
#include <vector>
10 
11 namespace armnn
12 {
13 
/// Identifier string for the Neon tensor handle factory.
///
/// Usable in constant expressions. NOTE(review): presumably this is the value
/// reported by NeonTensorHandleFactory::GetIdStatic()/GetId(); confirm in the .cpp.
constexpr const char* NeonTensorHandleFactoryId()
{
    return "Arm/Neon/TensorHandleFactory";
}
15 
16 const std::set<armnn::LayerType> paddingRequiredLayers {
17     LayerType::ArgMinMax,
18     LayerType::Convolution2d,
19     LayerType::DepthToSpace,
20     LayerType::DepthwiseConvolution2d,
21     LayerType::Dequantize,
22     LayerType::FullyConnected,
23     LayerType::Gather,
24     LayerType::Lstm,
25     LayerType::Mean,
26     LayerType::Permute,
27     LayerType::Pooling2d,
28     LayerType::Quantize,
29     LayerType::QuantizedLstm,
30     LayerType::Stack,
31     LayerType::TransposeConvolution2d
32 };
33 
34 class NeonTensorHandleFactory : public ITensorHandleFactory
35 {
36 public:
NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)37     NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
38                             : m_MemoryManager(mgr),
39                               m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
40                               m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
41     {}
42 
43     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
44                                                          const TensorShape& subTensorShape,
45                                                          const unsigned int* subTensorOrigin) const override;
46 
47     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
48 
49     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50                                                       DataLayout dataLayout) const override;
51 
52     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
53                                                       const bool IsMemoryManaged) const override;
54 
55     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
56                                                       DataLayout dataLayout,
57                                                       const bool IsMemoryManaged = true) const override;
58 
59     static const FactoryId& GetIdStatic();
60 
61     const FactoryId& GetId() const override;
62 
63     bool SupportsInPlaceComputation() const override;
64 
65     bool SupportsSubTensors() const override;
66 
67     MemorySourceFlags GetExportFlags() const override;
68 
69     MemorySourceFlags GetImportFlags() const override;
70 
71     std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
72                                             const IConnectableLayer* connectedLayer,
73                                             CapabilityClass capabilityClass) override;
74 
75 private:
76     mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
77     MemorySourceFlags m_ImportFlags;
78     MemorySourceFlags m_ExportFlags;
79 };
80 
81 } // namespace armnn
82