xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/TosaRefWorkloadFactory.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <Layer.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <backendsCommon/MemImportWorkload.hpp>
#include <backendsCommon/MakeWorkloadHelper.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include "TosaRefWorkloadFactory.hpp"
#include "TosaRefBackendId.hpp"
#include "workloads/TosaRefWorkloads.hpp"
#include "TosaRefTensorHandle.hpp"

namespace armnn
{

namespace
{
static const BackendId s_Id{TosaRefBackendId()};
}
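
// MakeWorkloadHelper appears to select the workload class that matches the data type reported
// by the WorkloadInfo; only the Float32 and Uint8 slots are populated here, so every other
// data type falls through to NullWorkload and produces no workload. A minimal usage sketch,
// assuming hypothetical workload types MyF32Workload and MyU8Workload:
//     MakeWorkload<MyF32Workload, MyU8Workload>(descriptor, info);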
template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
std::unique_ptr<IWorkload> TosaRefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
                                                                const WorkloadInfo& info) const
{
    return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
           (descriptor, info);
}
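
// IsDataType<> reports whether any input or output tensor described by the WorkloadInfo has the
// requested data type. A minimal usage sketch (illustrative only):
//     if (IsDataType<DataType::Float32>(info)) { /* take the float path */ }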
template <DataType ArmnnType>
bool IsDataType(const WorkloadInfo& info)
{
    auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
    auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
    if (it != std::end(info.m_InputTensorInfos))
    {
        return true;
    }
    it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
    if (it != std::end(info.m_OutputTensorInfos))
    {
        return true;
    }
    return false;
}
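
// A construction sketch (illustrative only): the default constructor owns a private
// TosaRefMemoryManager, while the overload below lets several factories share one.
//     auto memoryManager = std::make_shared<TosaRefMemoryManager>();
//     TosaRefWorkloadFactory factory(memoryManager);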
TosaRefWorkloadFactory::TosaRefWorkloadFactory(const std::shared_ptr<TosaRefMemoryManager>& memoryManager)
    : m_MemoryManager(memoryManager)
{
}

TosaRefWorkloadFactory::TosaRefWorkloadFactory()
    : m_MemoryManager(new TosaRefMemoryManager())
{
}

const BackendId& TosaRefWorkloadFactory::GetBackendId() const
{
    return s_Id;
}
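
// Both IsLayerSupported overloads forward to IWorkloadFactory::IsLayerSupported with this
// backend's id. A usage sketch (illustrative only; 'factory' and 'layer' are assumed to
// exist in the caller's scope):
//     std::string reason;
//     if (!factory.IsLayerSupported(layer, DataType::Float32, reason))
//     {
//         // 'reason' describes why the TOSA reference backend rejected the layer.
//     }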
bool TosaRefWorkloadFactory::IsLayerSupported(const Layer& layer,
                                              Optional<DataType> dataType,
                                              std::string& outReasonIfUnsupported)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
}

bool TosaRefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
                                              Optional<DataType> dataType,
                                              std::string& outReasonIfUnsupported,
                                              const ModelOptions& modelOptions)
{
    return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
}
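
// Both CreateTensorHandle overloads build a TosaRefTensorHandle: when isMemoryManaged is true
// the handle is backed by the factory's TosaRefMemoryManager, otherwise it is created with
// MemorySource::Malloc so that externally allocated memory can be imported. A usage sketch
// (illustrative only; the shape and data type are arbitrary):
//     TensorInfo tensorInfo(TensorShape({1, 16, 16, 3}), DataType::Float32);
//     auto managed   = factory.CreateTensorHandle(tensorInfo, true);
//     auto unmanaged = factory.CreateTensorHandle(tensorInfo, false);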
std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                          const bool isMemoryManaged) const
{
    if (isMemoryManaged)
    {
        return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
    }
    else
    {
        return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
    }
}

std::unique_ptr<ITensorHandle> TosaRefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
                                                                          DataLayout dataLayout,
                                                                          const bool isMemoryManaged) const
{
    // For TosaRef it is okay to make the TensorHandle memory managed as it can also store a pointer
    // to unmanaged memory. This also ensures memory alignment.
    // Only the data layout is unused by this backend; isMemoryManaged is honoured below.
    IgnoreUnused(dataLayout);

    if (isMemoryManaged)
    {
        return std::make_unique<TosaRefTensorHandle>(tensorInfo, m_MemoryManager);
    }
    else
    {
        return std::make_unique<TosaRefTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
    }
}
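
// Only pre-compiled layers yield a workload; every other layer type returns nullptr, since the
// backend is expected to run its lowered subgraphs through a single pre-compiled workload.
// A usage sketch (illustrative only; 'descriptor' must be a PreCompiledQueueDescriptor for the
// downcast inside the factory to be valid):
//     auto workload = factory.CreateWorkload(LayerType::PreCompiled, descriptor, info);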
std::unique_ptr<IWorkload> TosaRefWorkloadFactory::CreateWorkload(LayerType type,
                                                                  const QueueDescriptor& descriptor,
                                                                  const WorkloadInfo& info) const
{
    switch(type)
    {
        case LayerType::PreCompiled:
        {
            auto precompiledQueueDescriptor = PolymorphicDowncast<const PreCompiledQueueDescriptor*>(&descriptor);
            return std::make_unique<TosaRefPreCompiledWorkload>(*precompiledQueueDescriptor, info);
        }
        default:
            return nullptr;
    }
}

} // namespace armnn