1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <memory>
#include <set>
#include <vector>
10
11 namespace armnn
12 {
13
/// Returns the string identifier of the Neon tensor handle factory.
/// Declared constexpr so it is usable in constant expressions.
constexpr auto NeonTensorHandleFactoryId() -> const char*
{
    return "Arm/Neon/TensorHandleFactory";
}
15
16 const std::set<armnn::LayerType> paddingRequiredLayers {
17 LayerType::ArgMinMax,
18 LayerType::Convolution2d,
19 LayerType::DepthToSpace,
20 LayerType::DepthwiseConvolution2d,
21 LayerType::Dequantize,
22 LayerType::FullyConnected,
23 LayerType::Gather,
24 LayerType::Lstm,
25 LayerType::Mean,
26 LayerType::Permute,
27 LayerType::Pooling2d,
28 LayerType::Quantize,
29 LayerType::QuantizedLstm,
30 LayerType::Stack,
31 LayerType::TransposeConvolution2d
32 };
33
34 class NeonTensorHandleFactory : public ITensorHandleFactory
35 {
36 public:
NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)37 NeonTensorHandleFactory(std::weak_ptr<NeonMemoryManager> mgr)
38 : m_MemoryManager(mgr),
39 m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)),
40 m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
41 {}
42
43 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
44 const TensorShape& subTensorShape,
45 const unsigned int* subTensorOrigin) const override;
46
47 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
48
49 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50 DataLayout dataLayout) const override;
51
52 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
53 const bool IsMemoryManaged) const override;
54
55 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
56 DataLayout dataLayout,
57 const bool IsMemoryManaged = true) const override;
58
59 static const FactoryId& GetIdStatic();
60
61 const FactoryId& GetId() const override;
62
63 bool SupportsInPlaceComputation() const override;
64
65 bool SupportsSubTensors() const override;
66
67 MemorySourceFlags GetExportFlags() const override;
68
69 MemorySourceFlags GetImportFlags() const override;
70
71 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
72 const IConnectableLayer* connectedLayer,
73 CapabilityClass capabilityClass) override;
74
75 private:
76 mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
77 MemorySourceFlags m_ImportFlags;
78 MemorySourceFlags m_ExportFlags;
79 };
80
81 } // namespace armnn
82