xref: /aosp_15_r20/external/armnn/src/backends/cl/ClTensorHandleFactory.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClTensorHandleFactory.hpp"
7 #include "ClTensorHandle.hpp"
8 
9 #include <armnn/utility/NumericCast.hpp>
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 #include <arm_compute/core/Coordinates.h>
13 #include <arm_compute/runtime/CL/CLSubTensor.h>
14 #include <arm_compute/runtime/CL/CLTensor.h>
15 
16 namespace armnn
17 {
18 
19 using FactoryId = ITensorHandleFactory::FactoryId;
20 
CreateSubTensorHandle(ITensorHandle & parent,const TensorShape & subTensorShape,const unsigned int * subTensorOrigin) const21 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateSubTensorHandle(ITensorHandle& parent,
22                                                                             const TensorShape& subTensorShape,
23                                                                             const unsigned int* subTensorOrigin) const
24 {
25     arm_compute::Coordinates coords;
26     arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
27 
28     coords.set_num_dimensions(subTensorShape.GetNumDimensions());
29     for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
30     {
31         // Arm compute indexes tensor coords in reverse order.
32         unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
33         coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
34     }
35 
36     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
37 
38     // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
39     // must match the parent shapes
40     if (coords.x() != 0 || coords.y() != 0)
41     {
42         return nullptr;
43     }
44     if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
45     {
46         return nullptr;
47     }
48 
49     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
50     {
51         return nullptr;
52     }
53 
54     return std::make_unique<ClSubTensorHandle>(PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
55 }
56 
CreateTensorHandle(const TensorInfo & tensorInfo) const57 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
58 {
59     return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, true);
60 }
61 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const62 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
63                                                                          DataLayout dataLayout) const
64 {
65     return ClTensorHandleFactory::CreateTensorHandle(tensorInfo, dataLayout, true);
66 }
67 
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const68 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
69                                                                          const bool IsMemoryManaged) const
70 {
71     std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo);
72     if (!IsMemoryManaged)
73     {
74         ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
75     }
76     tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
77     return tensorHandle;
78 }
79 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const80 std::unique_ptr<ITensorHandle> ClTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
81                                                                          DataLayout dataLayout,
82                                                                          const bool IsMemoryManaged) const
83 {
84     std::unique_ptr<ClTensorHandle> tensorHandle = std::make_unique<ClTensorHandle>(tensorInfo, dataLayout);
85     if (!IsMemoryManaged)
86     {
87         ARMNN_LOG(warning) << "ClTensorHandleFactory only has support for memory managed.";
88     }
89     tensorHandle->SetMemoryGroup(m_MemoryManager->GetInterLayerMemoryGroup());
90     return tensorHandle;
91 }
92 
GetIdStatic()93 const FactoryId& ClTensorHandleFactory::GetIdStatic()
94 {
95     static const FactoryId s_Id(ClTensorHandleFactoryId());
96     return s_Id;
97 }
98 
GetId() const99 const FactoryId& ClTensorHandleFactory::GetId() const
100 {
101     return GetIdStatic();
102 }
103 
SupportsSubTensors() const104 bool ClTensorHandleFactory::SupportsSubTensors() const
105 {
106     return true;
107 }
108 
GetExportFlags() const109 MemorySourceFlags ClTensorHandleFactory::GetExportFlags() const
110 {
111     return MemorySourceFlags(MemorySource::Undefined);
112 }
113 
GetImportFlags() const114 MemorySourceFlags ClTensorHandleFactory::GetImportFlags() const
115 {
116     return MemorySourceFlags(MemorySource::Undefined);
117 }
118 
119 }    // namespace armnn
120