xref: /aosp_15_r20/external/armnn/src/backends/cl/ClImportTensorHandleFactory.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClImportTensorHandleFactory.hpp"
7 #include "ClImportTensorHandle.hpp"
8 
9 #include <armnn/utility/NumericCast.hpp>
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 #include <arm_compute/core/Coordinates.h>
13 #include <arm_compute/runtime/CL/CLTensor.h>
14 
15 namespace armnn
16 {
17 
18 using FactoryId = ITensorHandleFactory::FactoryId;
19 
CreateSubTensorHandle(ITensorHandle & parent,const TensorShape & subTensorShape,const unsigned int * subTensorOrigin) const20 std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateSubTensorHandle(
21     ITensorHandle& parent, const TensorShape& subTensorShape, const unsigned int* subTensorOrigin) const
22 {
23     arm_compute::Coordinates coords;
24     arm_compute::TensorShape shape = armcomputetensorutils::BuildArmComputeTensorShape(subTensorShape);
25 
26     coords.set_num_dimensions(subTensorShape.GetNumDimensions());
27     for (unsigned int i = 0; i < subTensorShape.GetNumDimensions(); ++i)
28     {
29         // Arm compute indexes tensor coords in reverse order.
30         unsigned int revertedIndex = subTensorShape.GetNumDimensions() - i - 1;
31         coords.set(i, armnn::numeric_cast<int>(subTensorOrigin[revertedIndex]));
32     }
33 
34     const arm_compute::TensorShape parentShape = armcomputetensorutils::BuildArmComputeTensorShape(parent.GetShape());
35 
36     // In order for ACL to support subtensors the concat axis cannot be on x or y and the values of x and y
37     // must match the parent shapes
38     if (coords.x() != 0 || coords.y() != 0)
39     {
40         return nullptr;
41     }
42     if ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
43     {
44         return nullptr;
45     }
46 
47     if (!::arm_compute::error_on_invalid_subtensor(__func__, __FILE__, __LINE__, parentShape, coords, shape))
48     {
49         return nullptr;
50     }
51 
52     return std::make_unique<ClImportSubTensorHandle>(
53         PolymorphicDowncast<IClTensorHandle*>(&parent), shape, coords);
54 }
55 
CreateTensorHandle(const TensorInfo & tensorInfo) const56 std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
57 {
58     std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
59                                                                                                 GetImportFlags());
60     return tensorHandle;
61 }
62 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const63 std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
64                                                                                DataLayout dataLayout) const
65 {
66     std::unique_ptr<ClImportTensorHandle> tensorHandle = std::make_unique<ClImportTensorHandle>(tensorInfo,
67                                                                                                 dataLayout,
68                                                                                                 GetImportFlags());
69     return tensorHandle;
70 }
71 
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const72 std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
73                                                                                const bool IsMemoryManaged) const
74 {
75     if (IsMemoryManaged)
76     {
77         throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
78     }
79     return CreateTensorHandle(tensorInfo);
80 }
81 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const82 std::unique_ptr<ITensorHandle> ClImportTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
83                                                                                DataLayout dataLayout,
84                                                                                const bool IsMemoryManaged) const
85 {
86     if (IsMemoryManaged)
87     {
88         throw InvalidArgumentException("ClImportTensorHandleFactory does not support memory managed tensors.");
89     }
90     return CreateTensorHandle(tensorInfo, dataLayout);
91 }
92 
GetIdStatic()93 const FactoryId& ClImportTensorHandleFactory::GetIdStatic()
94 {
95     static const FactoryId s_Id(ClImportTensorHandleFactoryId());
96     return s_Id;
97 }
98 
GetId() const99 const FactoryId& ClImportTensorHandleFactory::GetId() const
100 {
101     return GetIdStatic();
102 }
103 
SupportsSubTensors() const104 bool ClImportTensorHandleFactory::SupportsSubTensors() const
105 {
106     return true;
107 }
108 
SupportsMapUnmap() const109 bool ClImportTensorHandleFactory::SupportsMapUnmap() const
110 {
111     return false;
112 }
113 
GetExportFlags() const114 MemorySourceFlags ClImportTensorHandleFactory::GetExportFlags() const
115 {
116     return m_ExportFlags;
117 }
118 
GetImportFlags() const119 MemorySourceFlags ClImportTensorHandleFactory::GetImportFlags() const
120 {
121     return m_ImportFlags;
122 }
123 
GetCapabilities(const IConnectableLayer * layer,const IConnectableLayer * connectedLayer,CapabilityClass capabilityClass)124 std::vector<Capability> ClImportTensorHandleFactory::GetCapabilities(const IConnectableLayer* layer,
125                                                                      const IConnectableLayer* connectedLayer,
126                                                                      CapabilityClass capabilityClass)
127 {
128     IgnoreUnused(layer);
129     IgnoreUnused(connectedLayer);
130     std::vector<Capability> capabilities;
131     if (capabilityClass == CapabilityClass::FallbackImportDisabled)
132     {
133         Capability paddingCapability(CapabilityClass::FallbackImportDisabled, true);
134         capabilities.push_back(paddingCapability);
135     }
136     return capabilities;
137 }
138 
139 }    // namespace armnn