1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ClTensorHandleFactory.hpp"
7 #include "ClTensorHandle.hpp"
8
9 #include <armnn/utility/NumericCast.hpp>
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11
12 #include <arm_compute/core/Coordinates.h>
13 #include <arm_compute/runtime/CL/CLSubTensor.h>
14 #include <arm_compute/runtime/CL/CLTensor.h>
15
16 namespace armnn
17 {
18
19 using FactoryId = ITensorHandleFactory::FactoryId;
20
CreateSubTensorHandle(ITensorHandle & parent,const TensorShape & subTensorShape,const unsigned int * subTensorOrigin) const21 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
22 const TensorShape& subTensorShape,
23 const unsigned int* subTensorOrigin) const
24 {
25 arm_compute::Coordinates coords;
26 arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
27
28 coords.set_num_dimensions(subTensorShape.GetNumDimensions());
29 for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
30 {
31 // Arm compute indexes tensor coords in reverse order.
32 unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
33 coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
34 }
35
36 const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
37
38 // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
39 // must match the parent shapes
40 if (coords.x() != 0 || coords.y() != 0)
41 {
42 return nullptr;
43 }
44 if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
45 {
46 return nullptr;
47 }
48
49 if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
50 {
51 return nullptr;
52 }
53
54 return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
55 }
56
CreateTensorHandle(const TensorInfo & tensorInfo) const57 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
58 {
59 return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
60 }
61
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const62 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
63 DataLayout dataLayout) const
64 {
65 return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
66 }
67
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const68 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
69 const bool IsMemoryManaged) const
70 {
71 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
72 if (!IsMemoryManaged)
73 {
74 ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
75 }
76 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
77 return tensorHandle;
78 }
79
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const80 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
81 DataLayout dataLayout,
82 const bool IsMemoryManaged) const
83 {
84 std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
85 if (!IsMemoryManaged)
86 {
87 ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
88 }
89 tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
90 return tensorHandle;
91 }
92
GetIdStatic()93 const FactoryId& ClTensorHandleFactory::GetIdStatic()
94 {
95 static const FactoryId s_Id(ClTensorHandleFactoryId());
96 return s_Id;
97 }
98
GetId() const99 const FactoryId& ClTensorHandleFactory::GetId() const
100 {
101 return GetIdStatic();
102 }
103
SupportsSubTensors() const104 bool ClTensorHandleFactory::SupportsSubTensors() const
105 {
106 return true;
107 }
108
GetExportFlags() const109 MemorySourceFlags ClTensorHandleFactory::GetExportFlags() const
110 {
111 return MemorySourceFlags(MemorySource::Undefined);
112 }
113
GetImportFlags() const114 MemorySourceFlags ClTensorHandleFactory::GetImportFlags() const
115 {
116 return MemorySourceFlags(MemorySource::Undefined);
117 }
118
119 } // namespace armnn
120