xref: /aosp_15_r20/external/armnn/src/armnn/LoadedNetwork.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "LoadedNetwork.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "Layer.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "Graph.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "Profiling.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "HeapProfiling.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "WorkingMemHandle.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ExecutionData.hpp"
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/MemCopyWorkload.hpp>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/profiling/ArmNNProfiling.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/MemSyncWorkload.hpp>
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker #include <common/include/Processes.hpp>
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker namespace armnn
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker using namespace std;
37*89c4ff92SAndroid Build Coastguard Worker using namespace arm::pipe;
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker namespace
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker 
/**
 * Builds an error string of the form "<prefix> <error.what()>".
 *
 * @tparam ExceptionType Any type with a what() member whose result can be appended to a std::string.
 * @param prefix Text placed in front of the exception message. A null pointer is treated as an empty prefix.
 * @param error  The exception whose what() text is appended.
 * @return The prefix, one space, then the exception message.
 */
template <typename ExceptionType>
std::string ToErrorMessage(const char * prefix, const ExceptionType & error)
{
    // Direct concatenation: no stream object is needed just to join two C strings.
    std::string message = (prefix != nullptr) ? prefix : "";
    message += ' ';
    message += error.what();
    return message;
}
49*89c4ff92SAndroid Build Coastguard Worker 
AddLayerStructure(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,const Layer & layer,ProfilingGuid networkGuid)50*89c4ff92SAndroid Build Coastguard Worker void AddLayerStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
51*89c4ff92SAndroid Build Coastguard Worker                        const Layer& layer,
52*89c4ff92SAndroid Build Coastguard Worker                        ProfilingGuid networkGuid)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker     // Add layer to the post-optimisation network structure
55*89c4ff92SAndroid Build Coastguard Worker     std::string layerName = layer.GetNameStr().empty() ? "<Unnamed>" : layer.GetNameStr();
56*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->CreateNamedTypedChildEntity(layer.GetGuid(),
57*89c4ff92SAndroid Build Coastguard Worker                                                networkGuid,
58*89c4ff92SAndroid Build Coastguard Worker                                                layerName,
59*89c4ff92SAndroid Build Coastguard Worker                                                LabelsAndEventClasses::LAYER_GUID);
60*89c4ff92SAndroid Build Coastguard Worker     for (auto&& input : layer.GetInputSlots())
61*89c4ff92SAndroid Build Coastguard Worker     {
62*89c4ff92SAndroid Build Coastguard Worker         const IOutputSlot* source = input.GetConnectedOutputSlot();
63*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(source != NULL);
64*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateConnectionRelationship(ProfilingRelationshipType::RetentionLink,
65*89c4ff92SAndroid Build Coastguard Worker                                                     source->GetOwningLayerGuid(),
66*89c4ff92SAndroid Build Coastguard Worker                                                     layer.GetGuid());
67*89c4ff92SAndroid Build Coastguard Worker     }
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker 
AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,std::unique_ptr<IWorkload> & workload,const Layer & layer)70*89c4ff92SAndroid Build Coastguard Worker void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
71*89c4ff92SAndroid Build Coastguard Worker                           std::unique_ptr<IWorkload>& workload,
72*89c4ff92SAndroid Build Coastguard Worker                           const Layer& layer)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker     // Add workload to the post-optimisation network structure
75*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->CreateTypedEntity(workload->GetGuid(), LabelsAndEventClasses::WORKLOAD_GUID);
76*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->MarkEntityWithLabel(workload->GetGuid(),
77*89c4ff92SAndroid Build Coastguard Worker                                        layer.GetBackendId().Get(),
78*89c4ff92SAndroid Build Coastguard Worker                                        LabelsAndEventClasses::BACKENDID_GUID);
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Link the workload to the layer
81*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
82*89c4ff92SAndroid Build Coastguard Worker                                       layer.GetGuid(),
83*89c4ff92SAndroid Build Coastguard Worker                                       workload->GetGuid(),
84*89c4ff92SAndroid Build Coastguard Worker                                       LabelsAndEventClasses::CHILD_GUID);
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker } // anonymous
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker /**
90*89c4ff92SAndroid Build Coastguard Worker  * This function performs a sanity check to ensure that the combination of input and output memory source matches the
91*89c4ff92SAndroid Build Coastguard Worker  * values for importEnabled and exportEnabled that were specified during optimization. During optimization the tensor
92*89c4ff92SAndroid Build Coastguard Worker  * handle factories are chosen based on whether import and export are enabled. If the user then specifies something
93*89c4ff92SAndroid Build Coastguard Worker  * incompatible here it can lead to problems.
94*89c4ff92SAndroid Build Coastguard Worker  *
95*89c4ff92SAndroid Build Coastguard Worker  * @param optimizedOptions
96*89c4ff92SAndroid Build Coastguard Worker  * @param networkProperties
97*89c4ff92SAndroid Build Coastguard Worker  */
ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,const INetworkProperties & networkProperties)98*89c4ff92SAndroid Build Coastguard Worker void ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,
99*89c4ff92SAndroid Build Coastguard Worker                                           const INetworkProperties& networkProperties)
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker     // Find the "Global" backend options. During the optimize phase the values of importEnabled and exportEnabled are
102*89c4ff92SAndroid Build Coastguard Worker     // added as backend options.
103*89c4ff92SAndroid Build Coastguard Worker     const vector<BackendOptions>::iterator& backendItr =
104*89c4ff92SAndroid Build Coastguard Worker         find_if(optimizedOptions.begin(), optimizedOptions.end(), [](const BackendOptions& backend) {
105*89c4ff92SAndroid Build Coastguard Worker             if (backend.GetBackendId().Get() == "Global")
106*89c4ff92SAndroid Build Coastguard Worker             {
107*89c4ff92SAndroid Build Coastguard Worker                 return true;
108*89c4ff92SAndroid Build Coastguard Worker             }
109*89c4ff92SAndroid Build Coastguard Worker             else
110*89c4ff92SAndroid Build Coastguard Worker             {
111*89c4ff92SAndroid Build Coastguard Worker                 return false;
112*89c4ff92SAndroid Build Coastguard Worker             }
113*89c4ff92SAndroid Build Coastguard Worker         });
114*89c4ff92SAndroid Build Coastguard Worker     bool importEnabled = false;
115*89c4ff92SAndroid Build Coastguard Worker     bool exportEnabled = false;
116*89c4ff92SAndroid Build Coastguard Worker     if (backendItr != optimizedOptions.end())
117*89c4ff92SAndroid Build Coastguard Worker     {
118*89c4ff92SAndroid Build Coastguard Worker         // Find the importEnabled and exportEnabled values.
119*89c4ff92SAndroid Build Coastguard Worker         for (size_t i = 0; i < backendItr->GetOptionCount(); i++)
120*89c4ff92SAndroid Build Coastguard Worker         {
121*89c4ff92SAndroid Build Coastguard Worker             const BackendOptions::BackendOption& option = backendItr->GetOption(i);
122*89c4ff92SAndroid Build Coastguard Worker             if (option.GetName() == "ImportEnabled")
123*89c4ff92SAndroid Build Coastguard Worker             {
124*89c4ff92SAndroid Build Coastguard Worker                 importEnabled = option.GetValue().AsBool();
125*89c4ff92SAndroid Build Coastguard Worker             }
126*89c4ff92SAndroid Build Coastguard Worker             if (option.GetName() == "ExportEnabled")
127*89c4ff92SAndroid Build Coastguard Worker             {
128*89c4ff92SAndroid Build Coastguard Worker                 exportEnabled = option.GetValue().AsBool();
129*89c4ff92SAndroid Build Coastguard Worker             }
130*89c4ff92SAndroid Build Coastguard Worker         }
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     // Now that we have values for import and export compare them to the MemorySource variables.
134*89c4ff92SAndroid Build Coastguard Worker     // Any value of MemorySource that's not "Undefined" implies that we need to do an import of some kind.
135*89c4ff92SAndroid Build Coastguard Worker     if ((networkProperties.m_InputSource == MemorySource::Undefined && importEnabled) ||
136*89c4ff92SAndroid Build Coastguard Worker         (networkProperties.m_InputSource != MemorySource::Undefined && !importEnabled))
137*89c4ff92SAndroid Build Coastguard Worker     {
138*89c4ff92SAndroid Build Coastguard Worker         auto message = fmt::format("The input memory source specified, '{0}',", networkProperties.m_InputSource);
139*89c4ff92SAndroid Build Coastguard Worker         if (!importEnabled)
140*89c4ff92SAndroid Build Coastguard Worker         {
141*89c4ff92SAndroid Build Coastguard Worker             message.append(" requires that memory import be enabled. However, "
142*89c4ff92SAndroid Build Coastguard Worker                            "it was disabled when this network was optimized.");
143*89c4ff92SAndroid Build Coastguard Worker         }
144*89c4ff92SAndroid Build Coastguard Worker         else
145*89c4ff92SAndroid Build Coastguard Worker         {
146*89c4ff92SAndroid Build Coastguard Worker             message.append(" requires that memory import be disabled. However, "
147*89c4ff92SAndroid Build Coastguard Worker                            "it was enabled when this network was optimized.");
148*89c4ff92SAndroid Build Coastguard Worker         }
149*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(message);
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     if ((networkProperties.m_OutputSource == MemorySource::Undefined && exportEnabled) ||
153*89c4ff92SAndroid Build Coastguard Worker         (networkProperties.m_OutputSource != MemorySource::Undefined && !exportEnabled))
154*89c4ff92SAndroid Build Coastguard Worker     {
155*89c4ff92SAndroid Build Coastguard Worker         auto message = fmt::format("The output memory source specified, '{0}',", networkProperties.m_OutputSource);
156*89c4ff92SAndroid Build Coastguard Worker         if (!exportEnabled)
157*89c4ff92SAndroid Build Coastguard Worker         {
158*89c4ff92SAndroid Build Coastguard Worker             message.append(" requires that memory export be enabled. However, "
159*89c4ff92SAndroid Build Coastguard Worker                            "it was disabled when this network was optimized.");
160*89c4ff92SAndroid Build Coastguard Worker         }
161*89c4ff92SAndroid Build Coastguard Worker         else
162*89c4ff92SAndroid Build Coastguard Worker         {
163*89c4ff92SAndroid Build Coastguard Worker             message.append(" requires that memory export be disabled. However, "
164*89c4ff92SAndroid Build Coastguard Worker                            "it was enabled when this network was optimized.");
165*89c4ff92SAndroid Build Coastguard Worker         }
166*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(message);
167*89c4ff92SAndroid Build Coastguard Worker     }
168*89c4ff92SAndroid Build Coastguard Worker } // anonymous
169*89c4ff92SAndroid Build Coastguard Worker 
MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,std::string & errorMessage,const INetworkProperties & networkProperties,arm::pipe::IProfilingService * profilingService)170*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
171*89c4ff92SAndroid Build Coastguard Worker                                                                 std::string& errorMessage,
172*89c4ff92SAndroid Build Coastguard Worker                                                                 const INetworkProperties& networkProperties,
173*89c4ff92SAndroid Build Coastguard Worker                                                                 arm::pipe::IProfilingService* profilingService)
174*89c4ff92SAndroid Build Coastguard Worker {
175*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<LoadedNetwork> loadedNetwork;
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     auto Fail = [&](const std::exception& error) -> std::unique_ptr<LoadedNetwork>
178*89c4ff92SAndroid Build Coastguard Worker     {
179*89c4ff92SAndroid Build Coastguard Worker         errorMessage = ToErrorMessage("An error occurred when preparing the network workloads: ", error);
180*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << errorMessage;
181*89c4ff92SAndroid Build Coastguard Worker 
182*89c4ff92SAndroid Build Coastguard Worker         return std::unique_ptr<LoadedNetwork>();
183*89c4ff92SAndroid Build Coastguard Worker     };
184*89c4ff92SAndroid Build Coastguard Worker 
185*89c4ff92SAndroid Build Coastguard Worker     try
186*89c4ff92SAndroid Build Coastguard Worker     {
187*89c4ff92SAndroid Build Coastguard Worker         loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
188*89c4ff92SAndroid Build Coastguard Worker     }
189*89c4ff92SAndroid Build Coastguard Worker     catch (const armnn::RuntimeException& error)
190*89c4ff92SAndroid Build Coastguard Worker     {
191*89c4ff92SAndroid Build Coastguard Worker         return Fail(error);
192*89c4ff92SAndroid Build Coastguard Worker     }
193*89c4ff92SAndroid Build Coastguard Worker     catch (const armnn::Exception& error)
194*89c4ff92SAndroid Build Coastguard Worker     {
195*89c4ff92SAndroid Build Coastguard Worker         return Fail(error);
196*89c4ff92SAndroid Build Coastguard Worker     }
197*89c4ff92SAndroid Build Coastguard Worker     catch (const std::runtime_error& error)
198*89c4ff92SAndroid Build Coastguard Worker     {
199*89c4ff92SAndroid Build Coastguard Worker         return Fail(error);
200*89c4ff92SAndroid Build Coastguard Worker     }
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     return loadedNetwork;
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker 
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,const INetworkProperties & networkProperties,arm::pipe::IProfilingService * profilingService)205*89c4ff92SAndroid Build Coastguard Worker LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
206*89c4ff92SAndroid Build Coastguard Worker                              const INetworkProperties& networkProperties,
207*89c4ff92SAndroid Build Coastguard Worker                              arm::pipe::IProfilingService* profilingService) :
208*89c4ff92SAndroid Build Coastguard Worker                              m_OptimizedNetwork(std::move(net)),
209*89c4ff92SAndroid Build Coastguard Worker                              m_NetworkProperties(networkProperties),
210*89c4ff92SAndroid Build Coastguard Worker                              m_TensorHandleFactoryRegistry(),
211*89c4ff92SAndroid Build Coastguard Worker                              m_ProfilingService(profilingService)
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadedNetwork");
214*89c4ff92SAndroid Build Coastguard Worker     // Get the profiler and register it for the current thread.
215*89c4ff92SAndroid Build Coastguard Worker     const std::shared_ptr<IProfiler>& profiler = m_OptimizedNetwork->GetProfiler();
216*89c4ff92SAndroid Build Coastguard Worker     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
217*89c4ff92SAndroid Build Coastguard Worker 
218*89c4ff92SAndroid Build Coastguard Worker     profiler->EnableProfiling(networkProperties.m_ProfilingEnabled);
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker     profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod);
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker     // We need to check that the memory sources match up with the values of import and export specified during the
223*89c4ff92SAndroid Build Coastguard Worker     // optimize phase. If they don't this will throw an exception.
224*89c4ff92SAndroid Build Coastguard Worker     ValidateSourcesMatchOptimizedNetwork(m_OptimizedNetwork.get()->pOptimizedNetworkImpl->GetModelOptions(),
225*89c4ff92SAndroid Build Coastguard Worker                                          m_NetworkProperties);
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker     //First create tensor handlers, backends and workload factories.
228*89c4ff92SAndroid Build Coastguard Worker     //Handlers are created before workloads are.
229*89c4ff92SAndroid Build Coastguard Worker     //Because workload creation can modify some of the handlers,
230*89c4ff92SAndroid Build Coastguard Worker     //(for example the splitter and concat layers).
231*89c4ff92SAndroid Build Coastguard Worker 
232*89c4ff92SAndroid Build Coastguard Worker     bool useExternalMemoryManager = false;
233*89c4ff92SAndroid Build Coastguard Worker     bool useInternalMemoryManager = false;
234*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
235*89c4ff92SAndroid Build Coastguard Worker     // Ensure Topological order
236*89c4ff92SAndroid Build Coastguard Worker     order.SetLayersOutOfOrder();
237*89c4ff92SAndroid Build Coastguard Worker     order.TopologicalSort();
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker     if (!networkProperties.m_AsyncEnabled)
240*89c4ff92SAndroid Build Coastguard Worker     {
241*89c4ff92SAndroid Build Coastguard Worker         m_IsInputImported = std::vector<bool>(order.GetNumInputs(), false);
242*89c4ff92SAndroid Build Coastguard Worker         m_IsOutputImported = std::vector<bool>(order.GetNumOutputs(), false);
243*89c4ff92SAndroid Build Coastguard Worker     }
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
246*89c4ff92SAndroid Build Coastguard Worker     {
247*89c4ff92SAndroid Build Coastguard Worker         auto const& backendId = layer->GetBackendId();
248*89c4ff92SAndroid Build Coastguard Worker         if (m_Backends.count(backendId) == 0)
249*89c4ff92SAndroid Build Coastguard Worker         {
250*89c4ff92SAndroid Build Coastguard Worker             auto createBackend = BackendRegistryInstance().GetFactory(backendId);
251*89c4ff92SAndroid Build Coastguard Worker             auto it = m_Backends.emplace(std::make_pair(backendId, createBackend()));
252*89c4ff92SAndroid Build Coastguard Worker 
253*89c4ff92SAndroid Build Coastguard Worker             IBackendInternal* backend = it.first->second.get();
254*89c4ff92SAndroid Build Coastguard Worker 
255*89c4ff92SAndroid Build Coastguard Worker             // If we're doing async execution verify that the backend supports it and ExternallyManagedMemory.
256*89c4ff92SAndroid Build Coastguard Worker             if (networkProperties.m_AsyncEnabled)
257*89c4ff92SAndroid Build Coastguard Worker             {
258*89c4ff92SAndroid Build Coastguard Worker                 if (!HasCapability(BackendOptions::BackendOption{"AsyncExecution", true}, backend->GetCapabilities()))
259*89c4ff92SAndroid Build Coastguard Worker                 {
260*89c4ff92SAndroid Build Coastguard Worker                     std::string er = backend->GetId();
261*89c4ff92SAndroid Build Coastguard Worker                     er += " does not support AsyncExecution";
262*89c4ff92SAndroid Build Coastguard Worker                     throw BackendCapabilityException(er);
263*89c4ff92SAndroid Build Coastguard Worker                 }
264*89c4ff92SAndroid Build Coastguard Worker                 if (!HasCapability(BackendOptions::BackendOption{"ExternallyManagedMemory", true},
265*89c4ff92SAndroid Build Coastguard Worker                 backend->GetCapabilities()))
266*89c4ff92SAndroid Build Coastguard Worker                 {
267*89c4ff92SAndroid Build Coastguard Worker                     std::string er = backend->GetId();
268*89c4ff92SAndroid Build Coastguard Worker                     er += " does not support ExternallyManagedMemory\n";
269*89c4ff92SAndroid Build Coastguard Worker                     er += "AsyncEnabled networks require all backends to support ExternallyManagedMemory";
270*89c4ff92SAndroid Build Coastguard Worker                     throw BackendCapabilityException(er);
271*89c4ff92SAndroid Build Coastguard Worker                 }
272*89c4ff92SAndroid Build Coastguard Worker                 m_SupportsExternallyManagedMemory[backend->GetId()] = true;
273*89c4ff92SAndroid Build Coastguard Worker                 useExternalMemoryManager = true;
274*89c4ff92SAndroid Build Coastguard Worker             }
275*89c4ff92SAndroid Build Coastguard Worker             else
276*89c4ff92SAndroid Build Coastguard Worker             {
277*89c4ff92SAndroid Build Coastguard Worker                 m_SupportsExternallyManagedMemory[backend->GetId()] = false;
278*89c4ff92SAndroid Build Coastguard Worker                 useInternalMemoryManager = true;
279*89c4ff92SAndroid Build Coastguard Worker             }
280*89c4ff92SAndroid Build Coastguard Worker 
281*89c4ff92SAndroid Build Coastguard Worker             IBackendInternal::IWorkloadFactoryPtr workloadFactory;
282*89c4ff92SAndroid Build Coastguard Worker             if (backend->SupportsTensorAllocatorAPI())
283*89c4ff92SAndroid Build Coastguard Worker             {
284*89c4ff92SAndroid Build Coastguard Worker                 workloadFactory = backend->CreateWorkloadFactory(
285*89c4ff92SAndroid Build Coastguard Worker                     m_TensorHandleFactoryRegistry,
286*89c4ff92SAndroid Build Coastguard Worker                     m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions(),
287*89c4ff92SAndroid Build Coastguard Worker                     static_cast<MemorySourceFlags>(m_NetworkProperties.m_InputSource),
288*89c4ff92SAndroid Build Coastguard Worker                     static_cast<MemorySourceFlags>(m_NetworkProperties.m_OutputSource));
289*89c4ff92SAndroid Build Coastguard Worker             }
290*89c4ff92SAndroid Build Coastguard Worker             else
291*89c4ff92SAndroid Build Coastguard Worker             {
292*89c4ff92SAndroid Build Coastguard Worker                 m_BackendMemoryMangers.emplace_back(backend->CreateMemoryManager());
293*89c4ff92SAndroid Build Coastguard Worker                 workloadFactory = backend->CreateWorkloadFactory(
294*89c4ff92SAndroid Build Coastguard Worker                         m_BackendMemoryMangers.back(), m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions());
295*89c4ff92SAndroid Build Coastguard Worker             }
296*89c4ff92SAndroid Build Coastguard Worker             m_WorkloadFactories[backendId ] = std::move(workloadFactory);
297*89c4ff92SAndroid Build Coastguard Worker         }
298*89c4ff92SAndroid Build Coastguard Worker     }
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker     if (!networkProperties.m_AsyncEnabled)
301*89c4ff92SAndroid Build Coastguard Worker     {
302*89c4ff92SAndroid Build Coastguard Worker         for (auto&& layer : order)
303*89c4ff92SAndroid Build Coastguard Worker         {
304*89c4ff92SAndroid Build Coastguard Worker             auto& workloadFactory = GetWorkloadFactory(*layer);
305*89c4ff92SAndroid Build Coastguard Worker             bool supportsExternalManager = m_SupportsExternallyManagedMemory[layer->GetBackendId()];
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
308*89c4ff92SAndroid Build Coastguard Worker             {
309*89c4ff92SAndroid Build Coastguard Worker                 case LayerType::Input:
310*89c4ff92SAndroid Build Coastguard Worker                 case LayerType::MemImport:
311*89c4ff92SAndroid Build Coastguard Worker                 {
312*89c4ff92SAndroid Build Coastguard Worker                     // If IsImportEnabled is true then we need to set IsMemoryManaged
313*89c4ff92SAndroid Build Coastguard Worker                     // to false when creating TensorHandles
314*89c4ff92SAndroid Build Coastguard Worker                     layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
315*89c4ff92SAndroid Build Coastguard Worker                                                workloadFactory,
316*89c4ff92SAndroid Build Coastguard Worker                                                !supportsExternalManager && !m_NetworkProperties.m_ImportEnabled);
317*89c4ff92SAndroid Build Coastguard Worker                     break;
318*89c4ff92SAndroid Build Coastguard Worker                 }
319*89c4ff92SAndroid Build Coastguard Worker                 case LayerType::Constant:
320*89c4ff92SAndroid Build Coastguard Worker                 {
321*89c4ff92SAndroid Build Coastguard Worker                     layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory, true);
322*89c4ff92SAndroid Build Coastguard Worker                     break;
323*89c4ff92SAndroid Build Coastguard Worker                 }
324*89c4ff92SAndroid Build Coastguard Worker                 default:
325*89c4ff92SAndroid Build Coastguard Worker                 {
326*89c4ff92SAndroid Build Coastguard Worker                     // Look for a layer with 1 OutputSlot which has 1 connection and that connection is an Output Layer
327*89c4ff92SAndroid Build Coastguard Worker                     // If Export is enabled disable memory management so we can export, otherwise we do a copy
328*89c4ff92SAndroid Build Coastguard Worker                     if ((layer->GetNumOutputSlots() == 1) &&
329*89c4ff92SAndroid Build Coastguard Worker                        (layer->GetOutputSlots()[0].GetNumConnections() == 1) &&
330*89c4ff92SAndroid Build Coastguard Worker                        (layer->GetOutputSlots()[0].GetConnection(0)->GetOwningLayer().GetType() == LayerType::Output))
331*89c4ff92SAndroid Build Coastguard Worker                     {
332*89c4ff92SAndroid Build Coastguard Worker                         layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
333*89c4ff92SAndroid Build Coastguard Worker                                                    workloadFactory,
334*89c4ff92SAndroid Build Coastguard Worker                                                    !supportsExternalManager && !m_NetworkProperties.m_ExportEnabled);
335*89c4ff92SAndroid Build Coastguard Worker                     }
336*89c4ff92SAndroid Build Coastguard Worker                     else
337*89c4ff92SAndroid Build Coastguard Worker                     {
338*89c4ff92SAndroid Build Coastguard Worker                         layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
339*89c4ff92SAndroid Build Coastguard Worker                                                    workloadFactory,
340*89c4ff92SAndroid Build Coastguard Worker                                                    !supportsExternalManager);
341*89c4ff92SAndroid Build Coastguard Worker                     }
342*89c4ff92SAndroid Build Coastguard Worker                 }
343*89c4ff92SAndroid Build Coastguard Worker             }
344*89c4ff92SAndroid Build Coastguard Worker         }
345*89c4ff92SAndroid Build Coastguard Worker     }
346*89c4ff92SAndroid Build Coastguard Worker 
347*89c4ff92SAndroid Build Coastguard Worker     ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
348*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TimelineUtilityMethods> timelineUtils =
349*89c4ff92SAndroid Build Coastguard Worker         TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
350*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
351*89c4ff92SAndroid Build Coastguard Worker     {
352*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID);
353*89c4ff92SAndroid Build Coastguard Worker         // Mark the network with a start of life event
354*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->RecordEvent(networkGuid, LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
355*89c4ff92SAndroid Build Coastguard Worker         // and with the process ID
356*89c4ff92SAndroid Build Coastguard Worker         int processID = arm::pipe::GetCurrentProcessId();
357*89c4ff92SAndroid Build Coastguard Worker         std::stringstream ss;
358*89c4ff92SAndroid Build Coastguard Worker         ss << processID;
359*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->MarkEntityWithLabel(networkGuid, ss.str(), LabelsAndEventClasses::PROCESS_ID_GUID);
360*89c4ff92SAndroid Build Coastguard Worker     }
361*89c4ff92SAndroid Build Coastguard Worker 
362*89c4ff92SAndroid Build Coastguard Worker     std::vector<IWorkload*> ConstWorkloads;
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker     //Then create workloads.
365*89c4ff92SAndroid Build Coastguard Worker     {
366*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_CreateWorkloads");
367*89c4ff92SAndroid Build Coastguard Worker         for (auto&& layer: order)
368*89c4ff92SAndroid Build Coastguard Worker         {
369*89c4ff92SAndroid Build Coastguard Worker             if (timelineUtils)
370*89c4ff92SAndroid Build Coastguard Worker             {
371*89c4ff92SAndroid Build Coastguard Worker                 // Add layer to the post-optimisation network structure
372*89c4ff92SAndroid Build Coastguard Worker                 AddLayerStructure(timelineUtils, *layer, networkGuid);
373*89c4ff92SAndroid Build Coastguard Worker             }
374*89c4ff92SAndroid Build Coastguard Worker 
375*89c4ff92SAndroid Build Coastguard Worker             const IWorkloadFactory& workloadFactory = GetWorkloadFactory(*layer);
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
378*89c4ff92SAndroid Build Coastguard Worker             {
379*89c4ff92SAndroid Build Coastguard Worker                 case LayerType::Input:
380*89c4ff92SAndroid Build Coastguard Worker                 case LayerType::Output:
381*89c4ff92SAndroid Build Coastguard Worker                 {
382*89c4ff92SAndroid Build Coastguard Worker                     // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput().
383*89c4ff92SAndroid Build Coastguard Worker                     break;
384*89c4ff92SAndroid Build Coastguard Worker                 }
385*89c4ff92SAndroid Build Coastguard Worker                 default:
386*89c4ff92SAndroid Build Coastguard Worker                 {
387*89c4ff92SAndroid Build Coastguard Worker                     auto workload = layer->CreateWorkload(workloadFactory);
388*89c4ff92SAndroid Build Coastguard Worker 
389*89c4ff92SAndroid Build Coastguard Worker                     if (!workload)
390*89c4ff92SAndroid Build Coastguard Worker                     {
391*89c4ff92SAndroid Build Coastguard Worker                         const char* const layerName =
392*89c4ff92SAndroid Build Coastguard Worker                                 layer->GetNameStr().length() != 0 ? layer->GetName() : "<Unnamed>";
393*89c4ff92SAndroid Build Coastguard Worker                         throw InvalidArgumentException(
394*89c4ff92SAndroid Build Coastguard Worker                                 fmt::format("No workload created for layer (name: '{0}' type: '{1}') (compute '{2}')",
395*89c4ff92SAndroid Build Coastguard Worker                                             layerName, static_cast<int>(layer->GetType()), layer->GetBackendId().Get()
396*89c4ff92SAndroid Build Coastguard Worker                                 ));
397*89c4ff92SAndroid Build Coastguard Worker                     }
398*89c4ff92SAndroid Build Coastguard Worker 
399*89c4ff92SAndroid Build Coastguard Worker                     if (timelineUtils)
400*89c4ff92SAndroid Build Coastguard Worker                     {
401*89c4ff92SAndroid Build Coastguard Worker                         // Add workload to the post-optimisation network structure
402*89c4ff92SAndroid Build Coastguard Worker                         AddWorkloadStructure(timelineUtils, workload, *layer);
403*89c4ff92SAndroid Build Coastguard Worker                     }
404*89c4ff92SAndroid Build Coastguard Worker 
405*89c4ff92SAndroid Build Coastguard Worker                     // For async networks ConstantWorkloads are managed exclusively by LoadedNetwork
406*89c4ff92SAndroid Build Coastguard Worker                     // and are separated out from the other workloads
407*89c4ff92SAndroid Build Coastguard Worker                     if((networkProperties.m_AsyncEnabled  || useExternalMemoryManager) &&
408*89c4ff92SAndroid Build Coastguard Worker                         layer->GetType() == LayerType::Constant)
409*89c4ff92SAndroid Build Coastguard Worker                     {
410*89c4ff92SAndroid Build Coastguard Worker                         m_ConstantTensorHandles[layer->GetGuid()] =
411*89c4ff92SAndroid Build Coastguard Worker                                 layer->GetOutputSlot(0).GetOutputHandler().GetData();
412*89c4ff92SAndroid Build Coastguard Worker                         m_ConstantWorkloads[layer->GetGuid()] = std::move(workload);
413*89c4ff92SAndroid Build Coastguard Worker                     }
414*89c4ff92SAndroid Build Coastguard Worker                     else
415*89c4ff92SAndroid Build Coastguard Worker                     {
416*89c4ff92SAndroid Build Coastguard Worker                         m_WorkloadQueue.push_back(std::move(workload));
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker                         if (layer->GetType() == LayerType::Constant)
419*89c4ff92SAndroid Build Coastguard Worker                         {
420*89c4ff92SAndroid Build Coastguard Worker                             // Place the Constant Workloads into a queue so that they can be executed first
421*89c4ff92SAndroid Build Coastguard Worker                             ConstWorkloads.push_back(m_WorkloadQueue.back().get());
422*89c4ff92SAndroid Build Coastguard Worker                         }
423*89c4ff92SAndroid Build Coastguard Worker                     }
424*89c4ff92SAndroid Build Coastguard Worker                     // release the constant data in the layer.
425*89c4ff92SAndroid Build Coastguard Worker                     layer->ReleaseConstantData();
426*89c4ff92SAndroid Build Coastguard Worker                     break;
427*89c4ff92SAndroid Build Coastguard Worker                 }
428*89c4ff92SAndroid Build Coastguard Worker             }
429*89c4ff92SAndroid Build Coastguard Worker         }
430*89c4ff92SAndroid Build Coastguard Worker     }
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker     // Gather information about workloads for inputs & outputs
433*89c4ff92SAndroid Build Coastguard Worker     if (!networkProperties.m_AsyncEnabled && m_WorkloadQueue.size() != 0)
434*89c4ff92SAndroid Build Coastguard Worker     {
435*89c4ff92SAndroid Build Coastguard Worker         const int noOfInputs = armnn::numeric_cast<int>(order.GetNumInputs());
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker         // Get indices of all workloads connected to each input and
438*89c4ff92SAndroid Build Coastguard Worker         // check if they support tensor handle replacement
439*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* layer: order.GetInputLayers())
440*89c4ff92SAndroid Build Coastguard Worker         {
441*89c4ff92SAndroid Build Coastguard Worker             const auto bindingId = layer->GetBindingId();
442*89c4ff92SAndroid Build Coastguard Worker 
443*89c4ff92SAndroid Build Coastguard Worker             bool supportsReplacement = true;
444*89c4ff92SAndroid Build Coastguard Worker 
445*89c4ff92SAndroid Build Coastguard Worker             for (const auto inputSlot: layer->GetOutputSlot(0).GetConnections())
446*89c4ff92SAndroid Build Coastguard Worker             {
447*89c4ff92SAndroid Build Coastguard Worker                 auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(inputSlot->GetOwningLayer()));
448*89c4ff92SAndroid Build Coastguard Worker                 workloadIndex -= noOfInputs;
449*89c4ff92SAndroid Build Coastguard Worker 
450*89c4ff92SAndroid Build Coastguard Worker                 m_InputWorkloadSlotPairs[bindingId].emplace_back(WorkloadIndices{
451*89c4ff92SAndroid Build Coastguard Worker                         armnn::numeric_cast<unsigned int>(workloadIndex), inputSlot->GetSlotIndex()});
452*89c4ff92SAndroid Build Coastguard Worker 
453*89c4ff92SAndroid Build Coastguard Worker                 auto workload = m_WorkloadQueue[m_InputWorkloadSlotPairs[bindingId].back().m_WorkloadIndex].get();
454*89c4ff92SAndroid Build Coastguard Worker                 supportsReplacement &= workload->SupportsTensorHandleReplacement();
455*89c4ff92SAndroid Build Coastguard Worker             }
456*89c4ff92SAndroid Build Coastguard Worker 
457*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory::FactoryId factoryId = layer->GetOutputSlot(0).GetTensorHandleFactoryId();
458*89c4ff92SAndroid Build Coastguard Worker             // Get matching import factory Id
459*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory::FactoryId importFactoryId =
460*89c4ff92SAndroid Build Coastguard Worker                     m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
461*89c4ff92SAndroid Build Coastguard Worker 
462*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
463*89c4ff92SAndroid Build Coastguard Worker 
464*89c4ff92SAndroid Build Coastguard Worker             if (supportsReplacement && importFactory)
465*89c4ff92SAndroid Build Coastguard Worker             {
466*89c4ff92SAndroid Build Coastguard Worker                 m_PreImportedInputHandles.emplace_back(
467*89c4ff92SAndroid Build Coastguard Worker                         bindingId, importFactory->CreateTensorHandle(layer->GetOutputSlot(0).GetTensorInfo(), false));
468*89c4ff92SAndroid Build Coastguard Worker             }
469*89c4ff92SAndroid Build Coastguard Worker             else
470*89c4ff92SAndroid Build Coastguard Worker             {
471*89c4ff92SAndroid Build Coastguard Worker                 m_PreImportedInputHandles.emplace_back(bindingId, nullptr);
472*89c4ff92SAndroid Build Coastguard Worker             }
473*89c4ff92SAndroid Build Coastguard Worker         }
474*89c4ff92SAndroid Build Coastguard Worker 
475*89c4ff92SAndroid Build Coastguard Worker         // Get indices of all workloads connected to each output and
476*89c4ff92SAndroid Build Coastguard Worker         // check if they support tensor handle replacement
477*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* layer: order.GetOutputLayers())
478*89c4ff92SAndroid Build Coastguard Worker         {
479*89c4ff92SAndroid Build Coastguard Worker             const auto bindingId = layer->GetBindingId();
480*89c4ff92SAndroid Build Coastguard Worker 
481*89c4ff92SAndroid Build Coastguard Worker             const auto outputSlot = layer->GetInputSlot(0).GetConnectedOutputSlot();
482*89c4ff92SAndroid Build Coastguard Worker             auto& indices = m_OutputWorkloadSlotPairs[bindingId];
483*89c4ff92SAndroid Build Coastguard Worker 
484*89c4ff92SAndroid Build Coastguard Worker             auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(outputSlot->GetOwningLayer()));
485*89c4ff92SAndroid Build Coastguard Worker             workloadIndex -= noOfInputs;
486*89c4ff92SAndroid Build Coastguard Worker 
487*89c4ff92SAndroid Build Coastguard Worker             indices.m_OutputSlotIndices = WorkloadIndices{numeric_cast<unsigned int>(workloadIndex),
488*89c4ff92SAndroid Build Coastguard Worker                                                           outputSlot->CalculateIndexOnOwner()};
489*89c4ff92SAndroid Build Coastguard Worker 
490*89c4ff92SAndroid Build Coastguard Worker             bool supportsReplacement = true;
491*89c4ff92SAndroid Build Coastguard Worker             auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
492*89c4ff92SAndroid Build Coastguard Worker             supportsReplacement &= outputWorkload->SupportsTensorHandleReplacement();
493*89c4ff92SAndroid Build Coastguard Worker 
494*89c4ff92SAndroid Build Coastguard Worker             for (auto &inputSlot: outputSlot->GetConnections())
495*89c4ff92SAndroid Build Coastguard Worker             {
496*89c4ff92SAndroid Build Coastguard Worker                 if(inputSlot->GetOwningLayer().GetType() != LayerType::Output)
497*89c4ff92SAndroid Build Coastguard Worker                 {
498*89c4ff92SAndroid Build Coastguard Worker                     auto inWorkloadIndex = std::distance(order.begin(),
499*89c4ff92SAndroid Build Coastguard Worker                                                          order.GetPosInGraph(inputSlot->GetOwningLayer()));
500*89c4ff92SAndroid Build Coastguard Worker                     inWorkloadIndex -= noOfInputs;
501*89c4ff92SAndroid Build Coastguard Worker                     indices.m_InputSlotIndices.emplace_back(WorkloadIndices{numeric_cast<unsigned int>(inWorkloadIndex),
502*89c4ff92SAndroid Build Coastguard Worker                                                             inputSlot->GetSlotIndex()});
503*89c4ff92SAndroid Build Coastguard Worker                     auto inputWorkload = m_WorkloadQueue[indices.m_InputSlotIndices.back().m_WorkloadIndex].get();
504*89c4ff92SAndroid Build Coastguard Worker                     supportsReplacement &= inputWorkload->SupportsTensorHandleReplacement();
505*89c4ff92SAndroid Build Coastguard Worker                 }
506*89c4ff92SAndroid Build Coastguard Worker             }
507*89c4ff92SAndroid Build Coastguard Worker 
508*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId();
509*89c4ff92SAndroid Build Coastguard Worker             // Get matching import factory Id
510*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory::FactoryId importFactoryId =
511*89c4ff92SAndroid Build Coastguard Worker                     m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
512*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker             if (supportsReplacement && importFactory)
515*89c4ff92SAndroid Build Coastguard Worker             {
516*89c4ff92SAndroid Build Coastguard Worker                 m_PreImportedOutputHandles.emplace_back(
517*89c4ff92SAndroid Build Coastguard Worker                         bindingId, importFactory->CreateTensorHandle(outputSlot->GetTensorInfo(), false));
518*89c4ff92SAndroid Build Coastguard Worker             }
519*89c4ff92SAndroid Build Coastguard Worker             else
520*89c4ff92SAndroid Build Coastguard Worker             {
521*89c4ff92SAndroid Build Coastguard Worker                 m_PreImportedOutputHandles.emplace_back(bindingId, nullptr);
522*89c4ff92SAndroid Build Coastguard Worker             }
523*89c4ff92SAndroid Build Coastguard Worker         }
524*89c4ff92SAndroid Build Coastguard Worker     }
525*89c4ff92SAndroid Build Coastguard Worker 
526*89c4ff92SAndroid Build Coastguard Worker     for (auto&& workloadFactory : m_WorkloadFactories)
527*89c4ff92SAndroid Build Coastguard Worker     {
528*89c4ff92SAndroid Build Coastguard Worker         workloadFactory.second->AfterWorkloadsCreated();
529*89c4ff92SAndroid Build Coastguard Worker     }
530*89c4ff92SAndroid Build Coastguard Worker 
531*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
532*89c4ff92SAndroid Build Coastguard Worker     {
533*89c4ff92SAndroid Build Coastguard Worker         // Commit to send the post-optimisation network structure
534*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->Commit();
535*89c4ff92SAndroid Build Coastguard Worker     }
536*89c4ff92SAndroid Build Coastguard Worker 
537*89c4ff92SAndroid Build Coastguard Worker     if (useExternalMemoryManager)
538*89c4ff92SAndroid Build Coastguard Worker     {
539*89c4ff92SAndroid Build Coastguard Worker         if (networkProperties.m_AsyncEnabled)
540*89c4ff92SAndroid Build Coastguard Worker         {
541*89c4ff92SAndroid Build Coastguard Worker             CreateMemoryProfileAsync();
542*89c4ff92SAndroid Build Coastguard Worker         }
543*89c4ff92SAndroid Build Coastguard Worker         else
544*89c4ff92SAndroid Build Coastguard Worker         {
545*89c4ff92SAndroid Build Coastguard Worker             CreateMemoryProfile();
546*89c4ff92SAndroid Build Coastguard Worker         }
547*89c4ff92SAndroid Build Coastguard Worker 
548*89c4ff92SAndroid Build Coastguard Worker         auto backendStrategyMap = BackendRegistryInstance().GetMemoryOptimizerStrategies();
549*89c4ff92SAndroid Build Coastguard Worker         for (auto& backendMemoryProfile : m_MemBlockMap)
550*89c4ff92SAndroid Build Coastguard Worker         {
551*89c4ff92SAndroid Build Coastguard Worker             const BackendId& backendId = backendMemoryProfile.first;
552*89c4ff92SAndroid Build Coastguard Worker             if (backendStrategyMap.find(backendId) != backendStrategyMap.end())
553*89c4ff92SAndroid Build Coastguard Worker             {
554*89c4ff92SAndroid Build Coastguard Worker                 m_MemBinMap[backendId] = backendStrategyMap[backendId]->Optimize(backendMemoryProfile.second);
555*89c4ff92SAndroid Build Coastguard Worker             }
556*89c4ff92SAndroid Build Coastguard Worker             else
557*89c4ff92SAndroid Build Coastguard Worker             {
558*89c4ff92SAndroid Build Coastguard Worker                 m_MemBinMap[backendId] = m_ConstantStrategy->Optimize(backendMemoryProfile.second);
559*89c4ff92SAndroid Build Coastguard Worker             }
560*89c4ff92SAndroid Build Coastguard Worker         }
561*89c4ff92SAndroid Build Coastguard Worker 
562*89c4ff92SAndroid Build Coastguard Worker         if (!networkProperties.m_AsyncEnabled)
563*89c4ff92SAndroid Build Coastguard Worker         {
564*89c4ff92SAndroid Build Coastguard Worker             m_ExternalMemoryManager = CreateExternalMemoryManger(m_TensorMemory);
565*89c4ff92SAndroid Build Coastguard Worker 
566*89c4ff92SAndroid Build Coastguard Worker             // Sort m_TensorMemory, so its order matches m_Tensorhandles
567*89c4ff92SAndroid Build Coastguard Worker             std::sort(m_TensorMemory.begin(), m_TensorMemory.end(),
568*89c4ff92SAndroid Build Coastguard Worker                       [](const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& lhs,
569*89c4ff92SAndroid Build Coastguard Worker                          const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& rhs)
570*89c4ff92SAndroid Build Coastguard Worker                       {
571*89c4ff92SAndroid Build Coastguard Worker                           return lhs.first->m_OutputSlotId < rhs.first->m_OutputSlotId;
572*89c4ff92SAndroid Build Coastguard Worker                       });
573*89c4ff92SAndroid Build Coastguard Worker         }
574*89c4ff92SAndroid Build Coastguard Worker     }
575*89c4ff92SAndroid Build Coastguard Worker 
576*89c4ff92SAndroid Build Coastguard Worker     // Now that the intermediate tensor memory has been set-up,
577*89c4ff92SAndroid Build Coastguard Worker     // do any post allocation configuration for each workload.
578*89c4ff92SAndroid Build Coastguard Worker     if (!networkProperties.m_AsyncEnabled)
579*89c4ff92SAndroid Build Coastguard Worker     {
580*89c4ff92SAndroid Build Coastguard Worker         if (useInternalMemoryManager)
581*89c4ff92SAndroid Build Coastguard Worker         {
582*89c4ff92SAndroid Build Coastguard Worker             // Set up memory.
583*89c4ff92SAndroid Build Coastguard Worker             m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().AllocateDynamicBuffers();
584*89c4ff92SAndroid Build Coastguard Worker         }
585*89c4ff92SAndroid Build Coastguard Worker 
586*89c4ff92SAndroid Build Coastguard Worker         for (auto &workload : m_WorkloadQueue)
587*89c4ff92SAndroid Build Coastguard Worker         {
588*89c4ff92SAndroid Build Coastguard Worker             workload->PostAllocationConfigure();
589*89c4ff92SAndroid Build Coastguard Worker         }
590*89c4ff92SAndroid Build Coastguard Worker     }
591*89c4ff92SAndroid Build Coastguard Worker 
592*89c4ff92SAndroid Build Coastguard Worker     if (useExternalMemoryManager)
593*89c4ff92SAndroid Build Coastguard Worker     {
594*89c4ff92SAndroid Build Coastguard Worker         if (!networkProperties.m_AsyncEnabled)
595*89c4ff92SAndroid Build Coastguard Worker         {
596*89c4ff92SAndroid Build Coastguard Worker             AllocateAndExecuteConstantWorkloads();
597*89c4ff92SAndroid Build Coastguard Worker         }
598*89c4ff92SAndroid Build Coastguard Worker         else
599*89c4ff92SAndroid Build Coastguard Worker         {
600*89c4ff92SAndroid Build Coastguard Worker             AllocateAndExecuteConstantWorkloadsAsync();
601*89c4ff92SAndroid Build Coastguard Worker         }
602*89c4ff92SAndroid Build Coastguard Worker     }
603*89c4ff92SAndroid Build Coastguard Worker     // If synchronous, execute all constant layer workloads
604*89c4ff92SAndroid Build Coastguard Worker     if (!networkProperties.m_AsyncEnabled)
605*89c4ff92SAndroid Build Coastguard Worker     {
606*89c4ff92SAndroid Build Coastguard Worker         for (auto workload: ConstWorkloads)
607*89c4ff92SAndroid Build Coastguard Worker         {
608*89c4ff92SAndroid Build Coastguard Worker             workload->Execute();
609*89c4ff92SAndroid Build Coastguard Worker         }
610*89c4ff92SAndroid Build Coastguard Worker     }
611*89c4ff92SAndroid Build Coastguard Worker }
612*89c4ff92SAndroid Build Coastguard Worker 
AllocateAndExecuteConstantWorkloads()613*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateAndExecuteConstantWorkloads()
614*89c4ff92SAndroid Build Coastguard Worker {
615*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_AllocateAndExecuteConstants");
616*89c4ff92SAndroid Build Coastguard Worker     for (auto& pair : m_ConstantWorkloads)
617*89c4ff92SAndroid Build Coastguard Worker     {
618*89c4ff92SAndroid Build Coastguard Worker         auto tensorHandle = m_ConstantTensorHandles[pair.first];
619*89c4ff92SAndroid Build Coastguard Worker         tensorHandle->Allocate();
620*89c4ff92SAndroid Build Coastguard Worker         pair.second->Execute();
621*89c4ff92SAndroid Build Coastguard Worker     }
622*89c4ff92SAndroid Build Coastguard Worker }
623*89c4ff92SAndroid Build Coastguard Worker 
AllocateAndExecuteConstantWorkloadsAsync()624*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateAndExecuteConstantWorkloadsAsync()
625*89c4ff92SAndroid Build Coastguard Worker {
626*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_AllocateAndExecuteConstants");
627*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
628*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
629*89c4ff92SAndroid Build Coastguard Worker     {
630*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType() == LayerType::Constant)
631*89c4ff92SAndroid Build Coastguard Worker         {
632*89c4ff92SAndroid Build Coastguard Worker             const auto& outSlot = layer->GetOutputSlots()[0];
633*89c4ff92SAndroid Build Coastguard Worker             const auto factoryId = outSlot.GetTensorHandleFactoryId();
634*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(factoryId != ITensorHandleFactory::LegacyFactoryId);
635*89c4ff92SAndroid Build Coastguard Worker             auto& workloadFactory = GetWorkloadFactory(*layer);
636*89c4ff92SAndroid Build Coastguard Worker 
637*89c4ff92SAndroid Build Coastguard Worker             layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory);
638*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* tensorHandle = outSlot.GetOutputHandler().GetData();
639*89c4ff92SAndroid Build Coastguard Worker 
640*89c4ff92SAndroid Build Coastguard Worker             m_ConstantTensorHandles[layer->GetGuid()] = tensorHandle;
641*89c4ff92SAndroid Build Coastguard Worker             tensorHandle->Allocate();
642*89c4ff92SAndroid Build Coastguard Worker 
643*89c4ff92SAndroid Build Coastguard Worker             auto& backend = m_Backends.at(layer->GetBackendId());
644*89c4ff92SAndroid Build Coastguard Worker 
645*89c4ff92SAndroid Build Coastguard Worker             WorkingMemDescriptor memDesc;
646*89c4ff92SAndroid Build Coastguard Worker             memDesc.m_Outputs.push_back(tensorHandle);
647*89c4ff92SAndroid Build Coastguard Worker 
648*89c4ff92SAndroid Build Coastguard Worker             ExecutionData executionData = backend->CreateExecutionData(memDesc);
649*89c4ff92SAndroid Build Coastguard Worker             m_ConstantWorkloads[layer->GetGuid()]->ExecuteAsync(executionData);
650*89c4ff92SAndroid Build Coastguard Worker         }
651*89c4ff92SAndroid Build Coastguard Worker     }
652*89c4ff92SAndroid Build Coastguard Worker }
653*89c4ff92SAndroid Build Coastguard Worker 
SendNetworkStructure(arm::pipe::IProfilingService & profilingService)654*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::SendNetworkStructure(arm::pipe::IProfilingService& profilingService)
655*89c4ff92SAndroid Build Coastguard Worker {
656*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_SendNetworkStructure");
657*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
658*89c4ff92SAndroid Build Coastguard Worker     ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
659*89c4ff92SAndroid Build Coastguard Worker 
660*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TimelineUtilityMethods> timelineUtils =
661*89c4ff92SAndroid Build Coastguard Worker         TimelineUtilityMethods::GetTimelineUtils(profilingService);
662*89c4ff92SAndroid Build Coastguard Worker 
663*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID);
664*89c4ff92SAndroid Build Coastguard Worker 
665*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
666*89c4ff92SAndroid Build Coastguard Worker     {
667*89c4ff92SAndroid Build Coastguard Worker         // Add layer to the post-optimisation network structure
668*89c4ff92SAndroid Build Coastguard Worker         AddLayerStructure(timelineUtils, *layer, networkGuid);
669*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
670*89c4ff92SAndroid Build Coastguard Worker         {
671*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Input:
672*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Output:
673*89c4ff92SAndroid Build Coastguard Worker             {
674*89c4ff92SAndroid Build Coastguard Worker                 // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput().
675*89c4ff92SAndroid Build Coastguard Worker                 break;
676*89c4ff92SAndroid Build Coastguard Worker             }
677*89c4ff92SAndroid Build Coastguard Worker             default:
678*89c4ff92SAndroid Build Coastguard Worker             {
679*89c4ff92SAndroid Build Coastguard Worker                 for (auto& workload : m_WorkloadQueue)
680*89c4ff92SAndroid Build Coastguard Worker                 {
681*89c4ff92SAndroid Build Coastguard Worker                     // Add workload to the post-optimisation network structure
682*89c4ff92SAndroid Build Coastguard Worker                     AddWorkloadStructure(timelineUtils, workload, *layer);
683*89c4ff92SAndroid Build Coastguard Worker                 }
684*89c4ff92SAndroid Build Coastguard Worker             break;
685*89c4ff92SAndroid Build Coastguard Worker             }
686*89c4ff92SAndroid Build Coastguard Worker         }
687*89c4ff92SAndroid Build Coastguard Worker     }
688*89c4ff92SAndroid Build Coastguard Worker     // Commit to send the post-optimisation network structure
689*89c4ff92SAndroid Build Coastguard Worker     timelineUtils->Commit();
690*89c4ff92SAndroid Build Coastguard Worker }
691*89c4ff92SAndroid Build Coastguard Worker 
GetNetworkGuid()692*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid LoadedNetwork::GetNetworkGuid()
693*89c4ff92SAndroid Build Coastguard Worker {
694*89c4ff92SAndroid Build Coastguard Worker     return m_OptimizedNetwork->GetGuid();
695*89c4ff92SAndroid Build Coastguard Worker }
696*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensorInfo(LayerBindingId layerId) const697*89c4ff92SAndroid Build Coastguard Worker TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const
698*89c4ff92SAndroid Build Coastguard Worker {
699*89c4ff92SAndroid Build Coastguard Worker     for (auto&& inputLayer : m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetInputLayers())
700*89c4ff92SAndroid Build Coastguard Worker     {
701*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(inputLayer->GetNumOutputSlots() == 1, "Input layer should have exactly 1 output slot");
702*89c4ff92SAndroid Build Coastguard Worker         if (inputLayer->GetBindingId() == layerId)
703*89c4ff92SAndroid Build Coastguard Worker         {
704*89c4ff92SAndroid Build Coastguard Worker             return inputLayer->GetOutputSlot(0).GetTensorInfo();
705*89c4ff92SAndroid Build Coastguard Worker         }
706*89c4ff92SAndroid Build Coastguard Worker     }
707*89c4ff92SAndroid Build Coastguard Worker 
708*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException(fmt::format("No input layer is associated with id {}", layerId));
709*89c4ff92SAndroid Build Coastguard Worker }
710*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensorInfo(LayerBindingId layerId) const711*89c4ff92SAndroid Build Coastguard Worker TensorInfo LoadedNetwork::GetOutputTensorInfo(LayerBindingId layerId) const
712*89c4ff92SAndroid Build Coastguard Worker {
713*89c4ff92SAndroid Build Coastguard Worker     for (auto&& outputLayer : m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetOutputLayers())
714*89c4ff92SAndroid Build Coastguard Worker     {
715*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(outputLayer->GetNumInputSlots() == 1, "Output layer should have exactly 1 input slot");
716*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(outputLayer->GetInputSlot(0).GetConnection(), "Input slot on Output layer must be connected");
717*89c4ff92SAndroid Build Coastguard Worker         if (outputLayer->GetBindingId() == layerId)
718*89c4ff92SAndroid Build Coastguard Worker         {
719*89c4ff92SAndroid Build Coastguard Worker             return outputLayer->GetInputSlot(0).GetConnection()->GetTensorInfo();
720*89c4ff92SAndroid Build Coastguard Worker         }
721*89c4ff92SAndroid Build Coastguard Worker     }
722*89c4ff92SAndroid Build Coastguard Worker 
723*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException(fmt::format("No output layer is associated with id {}", layerId));
724*89c4ff92SAndroid Build Coastguard Worker }
725*89c4ff92SAndroid Build Coastguard Worker 
GetWorkloadFactory(const Layer & layer) const726*89c4ff92SAndroid Build Coastguard Worker const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) const
727*89c4ff92SAndroid Build Coastguard Worker {
728*89c4ff92SAndroid Build Coastguard Worker     const IWorkloadFactory* workloadFactory = nullptr;
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     auto it = m_WorkloadFactories.find(layer.GetBackendId());
731*89c4ff92SAndroid Build Coastguard Worker     if (it ==  m_WorkloadFactories.end())
732*89c4ff92SAndroid Build Coastguard Worker     {
733*89c4ff92SAndroid Build Coastguard Worker         throw RuntimeException(fmt::format("No workload factory for {0} to be used for layer: {1}",
734*89c4ff92SAndroid Build Coastguard Worker                                            layer.GetBackendId().Get(),
735*89c4ff92SAndroid Build Coastguard Worker                                            layer.GetNameStr()),
736*89c4ff92SAndroid Build Coastguard Worker                                            CHECK_LOCATION());
737*89c4ff92SAndroid Build Coastguard Worker     }
738*89c4ff92SAndroid Build Coastguard Worker 
739*89c4ff92SAndroid Build Coastguard Worker     workloadFactory = it->second.get();
740*89c4ff92SAndroid Build Coastguard Worker 
741*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(workloadFactory, "No workload factory");
742*89c4ff92SAndroid Build Coastguard Worker 
743*89c4ff92SAndroid Build Coastguard Worker     return *workloadFactory;
744*89c4ff92SAndroid Build Coastguard Worker }
745*89c4ff92SAndroid Build Coastguard Worker 
746*89c4ff92SAndroid Build Coastguard Worker namespace {
747*89c4ff92SAndroid Build Coastguard Worker 
748*89c4ff92SAndroid Build Coastguard Worker // Non-copyable class owning accelerator-specific tensor data.
749*89c4ff92SAndroid Build Coastguard Worker class TensorPin
750*89c4ff92SAndroid Build Coastguard Worker {
751*89c4ff92SAndroid Build Coastguard Worker public:
TensorPin(std::unique_ptr<ITensorHandle> handle,const TensorInfo & info,LayerBindingId id)752*89c4ff92SAndroid Build Coastguard Worker     TensorPin(std::unique_ptr<ITensorHandle> handle, const TensorInfo& info, LayerBindingId id)
753*89c4ff92SAndroid Build Coastguard Worker         : m_TensorHandle(std::move(handle))
754*89c4ff92SAndroid Build Coastguard Worker         , m_TensorInfo(info)
755*89c4ff92SAndroid Build Coastguard Worker         , m_Id(id)
756*89c4ff92SAndroid Build Coastguard Worker     {
757*89c4ff92SAndroid Build Coastguard Worker     }
758*89c4ff92SAndroid Build Coastguard Worker 
GetTensorHandle() const759*89c4ff92SAndroid Build Coastguard Worker     ITensorHandle* GetTensorHandle() const { return m_TensorHandle.get(); }
GetTensorInfo() const760*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
GetBindingId() const761*89c4ff92SAndroid Build Coastguard Worker     LayerBindingId GetBindingId() const { return m_Id; }
762*89c4ff92SAndroid Build Coastguard Worker 
763*89c4ff92SAndroid Build Coastguard Worker private:
764*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> m_TensorHandle;
765*89c4ff92SAndroid Build Coastguard Worker     TensorInfo m_TensorInfo;
766*89c4ff92SAndroid Build Coastguard Worker     LayerBindingId m_Id;
767*89c4ff92SAndroid Build Coastguard Worker };
768*89c4ff92SAndroid Build Coastguard Worker 
GetTensorPin(LayerBindingId id,const std::vector<TensorPin> & pins,char const * bindingPointDesc)769*89c4ff92SAndroid Build Coastguard Worker static const TensorPin& GetTensorPin(LayerBindingId id,
770*89c4ff92SAndroid Build Coastguard Worker     const std::vector<TensorPin>& pins,
771*89c4ff92SAndroid Build Coastguard Worker     char const* bindingPointDesc)
772*89c4ff92SAndroid Build Coastguard Worker {
773*89c4ff92SAndroid Build Coastguard Worker     auto it = std::find_if(pins.begin(), pins.end(),
774*89c4ff92SAndroid Build Coastguard Worker         [id](const TensorPin& pin)
775*89c4ff92SAndroid Build Coastguard Worker     {
776*89c4ff92SAndroid Build Coastguard Worker         return pin.GetBindingId() == id;
777*89c4ff92SAndroid Build Coastguard Worker     });
778*89c4ff92SAndroid Build Coastguard Worker 
779*89c4ff92SAndroid Build Coastguard Worker     if (it != pins.end())
780*89c4ff92SAndroid Build Coastguard Worker     {
781*89c4ff92SAndroid Build Coastguard Worker         return *it;
782*89c4ff92SAndroid Build Coastguard Worker     }
783*89c4ff92SAndroid Build Coastguard Worker     else
784*89c4ff92SAndroid Build Coastguard Worker     {
785*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("No tensor supplied for {0} {1}", bindingPointDesc, id));
786*89c4ff92SAndroid Build Coastguard Worker     }
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker 
789*89c4ff92SAndroid Build Coastguard Worker // Stores data that needs to be kept accessible for the entire execution of a workload.
790*89c4ff92SAndroid Build Coastguard Worker class WorkloadData
791*89c4ff92SAndroid Build Coastguard Worker {
792*89c4ff92SAndroid Build Coastguard Worker public:
WorkloadData(const InputTensors & inputTensors,const OutputTensors & outputTensors)793*89c4ff92SAndroid Build Coastguard Worker     WorkloadData(const InputTensors& inputTensors, const OutputTensors& outputTensors)
794*89c4ff92SAndroid Build Coastguard Worker     {
795*89c4ff92SAndroid Build Coastguard Worker         m_InputTensorPins.reserve(inputTensors.size());
796*89c4ff92SAndroid Build Coastguard Worker         m_OutputTensorPins.reserve(outputTensors.size());
797*89c4ff92SAndroid Build Coastguard Worker 
798*89c4ff92SAndroid Build Coastguard Worker         for (auto inputTensorPair : inputTensors)
799*89c4ff92SAndroid Build Coastguard Worker         {
800*89c4ff92SAndroid Build Coastguard Worker             auto inputTensor = inputTensorPair.second;
801*89c4ff92SAndroid Build Coastguard Worker 
802*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<ITensorHandle> tensorHandle =
803*89c4ff92SAndroid Build Coastguard Worker                 std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(),inputTensor.GetMemoryArea());
804*89c4ff92SAndroid Build Coastguard Worker             LayerBindingId layerId = inputTensorPair.first;
805*89c4ff92SAndroid Build Coastguard Worker 
806*89c4ff92SAndroid Build Coastguard Worker             m_InputTensorPins.emplace_back(std::move(tensorHandle), inputTensor.GetInfo(), layerId);
807*89c4ff92SAndroid Build Coastguard Worker         }
808*89c4ff92SAndroid Build Coastguard Worker 
809*89c4ff92SAndroid Build Coastguard Worker         for (auto outputTensorPair : outputTensors)
810*89c4ff92SAndroid Build Coastguard Worker         {
811*89c4ff92SAndroid Build Coastguard Worker             auto outputTensor = outputTensorPair.second;
812*89c4ff92SAndroid Build Coastguard Worker 
813*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<ITensorHandle> tensorHandle =
814*89c4ff92SAndroid Build Coastguard Worker                 std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(), outputTensor.GetMemoryArea());
815*89c4ff92SAndroid Build Coastguard Worker             LayerBindingId layerId = outputTensorPair.first;
816*89c4ff92SAndroid Build Coastguard Worker 
817*89c4ff92SAndroid Build Coastguard Worker             m_OutputTensorPins.emplace_back(std::move(tensorHandle), outputTensor.GetInfo(), layerId);
818*89c4ff92SAndroid Build Coastguard Worker         }
819*89c4ff92SAndroid Build Coastguard Worker     }
820*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensorPin(LayerBindingId id) const821*89c4ff92SAndroid Build Coastguard Worker     const TensorPin& GetInputTensorPin(LayerBindingId id) const
822*89c4ff92SAndroid Build Coastguard Worker     {
823*89c4ff92SAndroid Build Coastguard Worker         return GetTensorPin(id, m_InputTensorPins, "input");
824*89c4ff92SAndroid Build Coastguard Worker     }
825*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensorPin(LayerBindingId id) const826*89c4ff92SAndroid Build Coastguard Worker     const TensorPin& GetOutputTensorPin(LayerBindingId id) const
827*89c4ff92SAndroid Build Coastguard Worker     {
828*89c4ff92SAndroid Build Coastguard Worker         return GetTensorPin(id, m_OutputTensorPins, "output");
829*89c4ff92SAndroid Build Coastguard Worker     }
830*89c4ff92SAndroid Build Coastguard Worker 
831*89c4ff92SAndroid Build Coastguard Worker private:
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorPin> m_InputTensorPins;
834*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorPin> m_OutputTensorPins;
835*89c4ff92SAndroid Build Coastguard Worker };
836*89c4ff92SAndroid Build Coastguard Worker 
837*89c4ff92SAndroid Build Coastguard Worker }
838*89c4ff92SAndroid Build Coastguard Worker 
EnqueueWorkload(const InputTensors & inputTensors,const OutputTensors & outputTensors,std::vector<ImportedInputId> preImportedInputIds,std::vector<ImportedOutputId> preImportedOutputIds)839*89c4ff92SAndroid Build Coastguard Worker Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
840*89c4ff92SAndroid Build Coastguard Worker                                       const OutputTensors& outputTensors,
841*89c4ff92SAndroid Build Coastguard Worker                                       std::vector<ImportedInputId> preImportedInputIds,
842*89c4ff92SAndroid Build Coastguard Worker                                       std::vector<ImportedOutputId> preImportedOutputIds)
843*89c4ff92SAndroid Build Coastguard Worker {
844*89c4ff92SAndroid Build Coastguard Worker     const Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
845*89c4ff92SAndroid Build Coastguard Worker 
846*89c4ff92SAndroid Build Coastguard Worker     // Walk graph to determine the order of execution.
847*89c4ff92SAndroid Build Coastguard Worker     if (graph.GetNumLayers() < 2)
848*89c4ff92SAndroid Build Coastguard Worker     {
849*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(warning) << "IRuntime::EnqueueWorkload()::Less than two nodes in graph";
850*89c4ff92SAndroid Build Coastguard Worker         return Status::Failure;
851*89c4ff92SAndroid Build Coastguard Worker     }
852*89c4ff92SAndroid Build Coastguard Worker 
853*89c4ff92SAndroid Build Coastguard Worker     // Data that must be kept alive for the entire execution of the workload.
854*89c4ff92SAndroid Build Coastguard Worker     WorkloadData workloadData(inputTensors, outputTensors);
855*89c4ff92SAndroid Build Coastguard Worker 
856*89c4ff92SAndroid Build Coastguard Worker     // Input tensors can be provided as parameters or pre imported. Either way the number of
857*89c4ff92SAndroid Build Coastguard Worker     // tensors should match the number of inputs.
858*89c4ff92SAndroid Build Coastguard Worker     if (graph.GetNumInputs() != (inputTensors.size() + preImportedInputIds.size()))
859*89c4ff92SAndroid Build Coastguard Worker     {
860*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("Number of inputs provided does not match network.");
861*89c4ff92SAndroid Build Coastguard Worker     }
862*89c4ff92SAndroid Build Coastguard Worker 
863*89c4ff92SAndroid Build Coastguard Worker     // For each input to the network, call EnqueueInput with the data passed by the user.
864*89c4ff92SAndroid Build Coastguard Worker     {
865*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs");
866*89c4ff92SAndroid Build Coastguard Worker         m_InputQueue.clear();
867*89c4ff92SAndroid Build Coastguard Worker         m_InputQueue.reserve(graph.GetNumInputs());
868*89c4ff92SAndroid Build Coastguard Worker 
869*89c4ff92SAndroid Build Coastguard Worker         unsigned int inputIndex = 0;
870*89c4ff92SAndroid Build Coastguard Worker         unsigned int importedInputIdIndex = 0;
871*89c4ff92SAndroid Build Coastguard Worker         std::sort(preImportedInputIds.begin(), preImportedInputIds.end());
872*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* inputLayer : graph.GetInputLayers())
873*89c4ff92SAndroid Build Coastguard Worker         {
874*89c4ff92SAndroid Build Coastguard Worker             if (importedInputIdIndex < preImportedInputIds.size() &&
875*89c4ff92SAndroid Build Coastguard Worker                 inputIndex == preImportedInputIds[importedInputIdIndex])
876*89c4ff92SAndroid Build Coastguard Worker             {
877*89c4ff92SAndroid Build Coastguard Worker                 // Only replace tensorhandles if they have not already been replaced
878*89c4ff92SAndroid Build Coastguard Worker                 if (!m_IsInputImported[inputIndex])
879*89c4ff92SAndroid Build Coastguard Worker                 {
880*89c4ff92SAndroid Build Coastguard Worker                     auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
881*89c4ff92SAndroid Build Coastguard Worker 
882*89c4ff92SAndroid Build Coastguard Worker                     for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
883*89c4ff92SAndroid Build Coastguard Worker                     {
884*89c4ff92SAndroid Build Coastguard Worker                         auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
885*89c4ff92SAndroid Build Coastguard Worker                         workload->ReplaceInputTensorHandle(outputTensorHandle, workloadInfo.m_SlotIndex);
886*89c4ff92SAndroid Build Coastguard Worker                     }
887*89c4ff92SAndroid Build Coastguard Worker                     m_IsInputImported[inputIndex] = true;
888*89c4ff92SAndroid Build Coastguard Worker                 }
889*89c4ff92SAndroid Build Coastguard Worker                 importedInputIdIndex++;
890*89c4ff92SAndroid Build Coastguard Worker             }
891*89c4ff92SAndroid Build Coastguard Worker             else
892*89c4ff92SAndroid Build Coastguard Worker             {
893*89c4ff92SAndroid Build Coastguard Worker                 if (m_IsInputImported[inputIndex])
894*89c4ff92SAndroid Build Coastguard Worker                 {
895*89c4ff92SAndroid Build Coastguard Worker                     OutputHandler& handler = const_cast<OutputHandler&>(inputLayer->GetOutputHandler(0));
896*89c4ff92SAndroid Build Coastguard Worker 
897*89c4ff92SAndroid Build Coastguard Worker                     for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
898*89c4ff92SAndroid Build Coastguard Worker                     {
899*89c4ff92SAndroid Build Coastguard Worker                         auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
900*89c4ff92SAndroid Build Coastguard Worker                         workload->ReplaceInputTensorHandle(handler.GetData(), workloadInfo.m_SlotIndex);
901*89c4ff92SAndroid Build Coastguard Worker                     }
902*89c4ff92SAndroid Build Coastguard Worker 
903*89c4ff92SAndroid Build Coastguard Worker                     m_IsInputImported[inputIndex] = false;
904*89c4ff92SAndroid Build Coastguard Worker                 }
905*89c4ff92SAndroid Build Coastguard Worker 
906*89c4ff92SAndroid Build Coastguard Worker                 // InputTensorHandle is not imported yet, process to enqueue input
907*89c4ff92SAndroid Build Coastguard Worker                 const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId());
908*89c4ff92SAndroid Build Coastguard Worker                 EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
909*89c4ff92SAndroid Build Coastguard Worker             }
910*89c4ff92SAndroid Build Coastguard Worker             inputIndex++;
911*89c4ff92SAndroid Build Coastguard Worker         }
912*89c4ff92SAndroid Build Coastguard Worker     }
913*89c4ff92SAndroid Build Coastguard Worker     // For each output to the network, call EnqueueOutput with the data passed by the user.
914*89c4ff92SAndroid Build Coastguard Worker     {
915*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs");
916*89c4ff92SAndroid Build Coastguard Worker         m_OutputQueue.clear();
917*89c4ff92SAndroid Build Coastguard Worker         m_OutputQueue.reserve(graph.GetNumOutputs());
918*89c4ff92SAndroid Build Coastguard Worker 
919*89c4ff92SAndroid Build Coastguard Worker         if (preImportedOutputIds.size() > graph.GetNumOutputs())
920*89c4ff92SAndroid Build Coastguard Worker         {
921*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("Invalid number of preImportedOutputIds");
922*89c4ff92SAndroid Build Coastguard Worker         }
923*89c4ff92SAndroid Build Coastguard Worker 
924*89c4ff92SAndroid Build Coastguard Worker         unsigned int outputIndex = 0;
925*89c4ff92SAndroid Build Coastguard Worker         unsigned int importedOutputIdIndex = 0;
926*89c4ff92SAndroid Build Coastguard Worker         std::sort(preImportedOutputIds.begin(), preImportedOutputIds.end());
927*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* outputLayer : graph.GetOutputLayers())
928*89c4ff92SAndroid Build Coastguard Worker         {
929*89c4ff92SAndroid Build Coastguard Worker             if (importedOutputIdIndex < preImportedOutputIds.size() &&
930*89c4ff92SAndroid Build Coastguard Worker                 outputIndex == preImportedOutputIds[importedOutputIdIndex])
931*89c4ff92SAndroid Build Coastguard Worker             {
932*89c4ff92SAndroid Build Coastguard Worker                 // Only replace tensorhandles if they have not already been replaced
933*89c4ff92SAndroid Build Coastguard Worker                 ITensorHandle* inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
934*89c4ff92SAndroid Build Coastguard Worker 
935*89c4ff92SAndroid Build Coastguard Worker                 if (!m_IsOutputImported[outputIndex])
936*89c4ff92SAndroid Build Coastguard Worker                 {
937*89c4ff92SAndroid Build Coastguard Worker                     const auto bindingId = outputLayer->GetBindingId();
938*89c4ff92SAndroid Build Coastguard Worker                     const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
939*89c4ff92SAndroid Build Coastguard Worker 
940*89c4ff92SAndroid Build Coastguard Worker                     auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
941*89c4ff92SAndroid Build Coastguard Worker 
942*89c4ff92SAndroid Build Coastguard Worker                     outputWorkload->ReplaceOutputTensorHandle(inputTensorHandle,
943*89c4ff92SAndroid Build Coastguard Worker                                                               indices.m_OutputSlotIndices.m_SlotIndex);
944*89c4ff92SAndroid Build Coastguard Worker 
945*89c4ff92SAndroid Build Coastguard Worker                     for (const auto& workloadInfo: indices.m_InputSlotIndices)
946*89c4ff92SAndroid Build Coastguard Worker                     {
947*89c4ff92SAndroid Build Coastguard Worker                         auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
948*89c4ff92SAndroid Build Coastguard Worker                         inputWorkload->ReplaceInputTensorHandle(inputTensorHandle, workloadInfo.m_SlotIndex);
949*89c4ff92SAndroid Build Coastguard Worker                     }
950*89c4ff92SAndroid Build Coastguard Worker                     m_IsOutputImported[outputIndex] = true;
951*89c4ff92SAndroid Build Coastguard Worker                 }
952*89c4ff92SAndroid Build Coastguard Worker 
953*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated.");
954*89c4ff92SAndroid Build Coastguard Worker                 MemSyncQueueDescriptor syncDesc;
955*89c4ff92SAndroid Build Coastguard Worker                 syncDesc.m_Inputs.push_back(inputTensorHandle);
956*89c4ff92SAndroid Build Coastguard Worker                 WorkloadInfo info;
957*89c4ff92SAndroid Build Coastguard Worker                 info.m_InputTensorInfos.push_back(
958*89c4ff92SAndroid Build Coastguard Worker                         outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo());
959*89c4ff92SAndroid Build Coastguard Worker                 auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
960*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
961*89c4ff92SAndroid Build Coastguard Worker                 m_OutputQueue.push_back(move(syncWorkload));
962*89c4ff92SAndroid Build Coastguard Worker                 importedOutputIdIndex++;
963*89c4ff92SAndroid Build Coastguard Worker             }
964*89c4ff92SAndroid Build Coastguard Worker             else
965*89c4ff92SAndroid Build Coastguard Worker             {
966*89c4ff92SAndroid Build Coastguard Worker                 if (m_IsOutputImported[outputIndex])
967*89c4ff92SAndroid Build Coastguard Worker                 {
968*89c4ff92SAndroid Build Coastguard Worker                     const auto bindingId = outputLayer->GetBindingId();
969*89c4ff92SAndroid Build Coastguard Worker                     const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
970*89c4ff92SAndroid Build Coastguard Worker 
971*89c4ff92SAndroid Build Coastguard Worker                     auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
972*89c4ff92SAndroid Build Coastguard Worker                     const OutputHandler& outputHandler =
973*89c4ff92SAndroid Build Coastguard Worker                             outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOutputHandler();
974*89c4ff92SAndroid Build Coastguard Worker 
975*89c4ff92SAndroid Build Coastguard Worker                     outputWorkload->ReplaceOutputTensorHandle(
976*89c4ff92SAndroid Build Coastguard Worker                             outputHandler.GetData(), indices.m_OutputSlotIndices.m_SlotIndex);
977*89c4ff92SAndroid Build Coastguard Worker 
978*89c4ff92SAndroid Build Coastguard Worker                     for (const auto& workloadInfo: indices.m_InputSlotIndices)
979*89c4ff92SAndroid Build Coastguard Worker                     {
980*89c4ff92SAndroid Build Coastguard Worker                         auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
981*89c4ff92SAndroid Build Coastguard Worker                         inputWorkload->ReplaceInputTensorHandle(outputHandler.GetData(), workloadInfo.m_SlotIndex);
982*89c4ff92SAndroid Build Coastguard Worker                     }
983*89c4ff92SAndroid Build Coastguard Worker                     m_IsOutputImported[outputIndex] = false;
984*89c4ff92SAndroid Build Coastguard Worker                 }
985*89c4ff92SAndroid Build Coastguard Worker 
986*89c4ff92SAndroid Build Coastguard Worker                 const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId());
987*89c4ff92SAndroid Build Coastguard Worker                 // OutputTensorHandle is not imported yet, process to enqueue Output
988*89c4ff92SAndroid Build Coastguard Worker                 EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
989*89c4ff92SAndroid Build Coastguard Worker             }
990*89c4ff92SAndroid Build Coastguard Worker             outputIndex++;
991*89c4ff92SAndroid Build Coastguard Worker         }
992*89c4ff92SAndroid Build Coastguard Worker     }
993*89c4ff92SAndroid Build Coastguard Worker 
994*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TimelineUtilityMethods> timelineUtils =
995*89c4ff92SAndroid Build Coastguard Worker                         TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
996*89c4ff92SAndroid Build Coastguard Worker     ProfilingGuid inferenceGuid = m_ProfilingService->GetNextGuid();
997*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
998*89c4ff92SAndroid Build Coastguard Worker     {
999*89c4ff92SAndroid Build Coastguard Worker         // Add inference timeline trace if profiling is enabled.
1000*89c4ff92SAndroid Build Coastguard Worker         ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
1001*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateTypedEntity(inferenceGuid, LabelsAndEventClasses::INFERENCE_GUID);
1002*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
1003*89c4ff92SAndroid Build Coastguard Worker                                           networkGuid,
1004*89c4ff92SAndroid Build Coastguard Worker                                           inferenceGuid,
1005*89c4ff92SAndroid Build Coastguard Worker                                           LabelsAndEventClasses::EXECUTION_OF_GUID);
1006*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
1007*89c4ff92SAndroid Build Coastguard Worker     }
1008*89c4ff92SAndroid Build Coastguard Worker 
1009*89c4ff92SAndroid Build Coastguard Worker     bool executionSucceeded = true;
1010*89c4ff92SAndroid Build Coastguard Worker 
1011*89c4ff92SAndroid Build Coastguard Worker     {
1012*89c4ff92SAndroid Build Coastguard Worker         if (m_ProfilingService->IsProfilingEnabled())
1013*89c4ff92SAndroid Build Coastguard Worker         {
1014*89c4ff92SAndroid Build Coastguard Worker             m_ProfilingService->IncrementCounterValue(INFERENCES_RUN);
1015*89c4ff92SAndroid Build Coastguard Worker         }
1016*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
1017*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_HEAP_PROFILING("Executing");
1018*89c4ff92SAndroid Build Coastguard Worker         executionSucceeded = Execute(timelineUtils, inferenceGuid);
1019*89c4ff92SAndroid Build Coastguard Worker     }
1020*89c4ff92SAndroid Build Coastguard Worker 
1021*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
1022*89c4ff92SAndroid Build Coastguard Worker     {
1023*89c4ff92SAndroid Build Coastguard Worker         // Add end of life of the inference timeline if profiling is enabled.
1024*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
1025*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->Commit();
1026*89c4ff92SAndroid Build Coastguard Worker     }
1027*89c4ff92SAndroid Build Coastguard Worker 
1028*89c4ff92SAndroid Build Coastguard Worker     return executionSucceeded ? Status::Success : Status::Failure;
1029*89c4ff92SAndroid Build Coastguard Worker }
1030*89c4ff92SAndroid Build Coastguard Worker 
EnqueueInput(const BindableLayer & layer,ITensorHandle * tensorHandle,const TensorInfo & tensorInfo)1031*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo)
1032*89c4ff92SAndroid Build Coastguard Worker {
1033*89c4ff92SAndroid Build Coastguard Worker     if (layer.GetType() != LayerType::Input)
1034*89c4ff92SAndroid Build Coastguard Worker     {
1035*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("EnqueueInput: given layer not an InputLayer");
1036*89c4ff92SAndroid Build Coastguard Worker     }
1037*89c4ff92SAndroid Build Coastguard Worker 
1038*89c4ff92SAndroid Build Coastguard Worker     if (tensorHandle == nullptr)
1039*89c4ff92SAndroid Build Coastguard Worker     {
1040*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("EnqueueInput: tensorHandle must not be NULL");
1041*89c4ff92SAndroid Build Coastguard Worker     }
1042*89c4ff92SAndroid Build Coastguard Worker 
1043*89c4ff92SAndroid Build Coastguard Worker     InputQueueDescriptor inputQueueDescriptor;
1044*89c4ff92SAndroid Build Coastguard Worker     WorkloadInfo info;
1045*89c4ff92SAndroid Build Coastguard Worker 
1046*89c4ff92SAndroid Build Coastguard Worker     inputQueueDescriptor.m_Inputs.push_back(tensorHandle);
1047*89c4ff92SAndroid Build Coastguard Worker     info.m_InputTensorInfos.push_back(tensorInfo);
1048*89c4ff92SAndroid Build Coastguard Worker 
1049*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(layer.GetNumOutputSlots() == 1, "Can only handle Input Layer with one output");
1050*89c4ff92SAndroid Build Coastguard Worker     const OutputHandler& handler = layer.GetOutputHandler();
1051*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& outputTensorInfo = handler.GetTensorInfo();
1052*89c4ff92SAndroid Build Coastguard Worker     ITensorHandle* outputTensorHandle = handler.GetData();
1053*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(outputTensorHandle != nullptr,
1054*89c4ff92SAndroid Build Coastguard Worker                      "Data should have been allocated.");
1055*89c4ff92SAndroid Build Coastguard Worker     inputQueueDescriptor.m_Outputs.push_back(outputTensorHandle);
1056*89c4ff92SAndroid Build Coastguard Worker     info.m_OutputTensorInfos.push_back(outputTensorInfo);
1057*89c4ff92SAndroid Build Coastguard Worker 
1058*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags importFlags = outputTensorHandle->GetImportFlags();
1059*89c4ff92SAndroid Build Coastguard Worker     bool needMemCopy = true;
1060*89c4ff92SAndroid Build Coastguard Worker     if (m_NetworkProperties.m_ImportEnabled)  // Try import the input tensor
1061*89c4ff92SAndroid Build Coastguard Worker     {
1062*89c4ff92SAndroid Build Coastguard Worker         if(CheckFlag(importFlags, m_NetworkProperties.m_InputSource))
1063*89c4ff92SAndroid Build Coastguard Worker         {
1064*89c4ff92SAndroid Build Coastguard Worker             needMemCopy = false;
1065*89c4ff92SAndroid Build Coastguard Worker             // This assumes a CPU Tensor handle
1066*89c4ff92SAndroid Build Coastguard Worker             void* mem = tensorHandle->Map(false);
1067*89c4ff92SAndroid Build Coastguard Worker             if (outputTensorHandle->Import(mem, m_NetworkProperties.m_InputSource))
1068*89c4ff92SAndroid Build Coastguard Worker             {
1069*89c4ff92SAndroid Build Coastguard Worker                 tensorHandle->Unmap();
1070*89c4ff92SAndroid Build Coastguard Worker                 return; // No need for a workload since the import has been done.
1071*89c4ff92SAndroid Build Coastguard Worker             }
1072*89c4ff92SAndroid Build Coastguard Worker             tensorHandle->Unmap();
1073*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("EnqueueInput: Memory Import failed");
1074*89c4ff92SAndroid Build Coastguard Worker         }
1075*89c4ff92SAndroid Build Coastguard Worker     }
1076*89c4ff92SAndroid Build Coastguard Worker     if (needMemCopy)
1077*89c4ff92SAndroid Build Coastguard Worker     {
1078*89c4ff92SAndroid Build Coastguard Worker         // Create a mem copy workload for input since we did not import
1079*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<IWorkload> inputWorkload = std::make_unique<CopyMemGenericWorkload>(inputQueueDescriptor, info);
1080*89c4ff92SAndroid Build Coastguard Worker 
1081*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(inputWorkload, "No input workload created");
1082*89c4ff92SAndroid Build Coastguard Worker 
1083*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1084*89c4ff92SAndroid Build Coastguard Worker                             TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1085*89c4ff92SAndroid Build Coastguard Worker         if (timelineUtils)
1086*89c4ff92SAndroid Build Coastguard Worker         {
1087*89c4ff92SAndroid Build Coastguard Worker             // Add Input Workload to the post-optimisation network structure
1088*89c4ff92SAndroid Build Coastguard Worker             AddWorkloadStructure(timelineUtils, inputWorkload, layer);
1089*89c4ff92SAndroid Build Coastguard Worker             timelineUtils->Commit();
1090*89c4ff92SAndroid Build Coastguard Worker         }
1091*89c4ff92SAndroid Build Coastguard Worker 
1092*89c4ff92SAndroid Build Coastguard Worker         m_InputQueue.push_back(move(inputWorkload));
1093*89c4ff92SAndroid Build Coastguard Worker     }
1094*89c4ff92SAndroid Build Coastguard Worker }
1095*89c4ff92SAndroid Build Coastguard Worker 
EnqueueOutput(const BindableLayer & layer,ITensorHandle * tensorHandle,const TensorInfo & tensorInfo)1096*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo)
1097*89c4ff92SAndroid Build Coastguard Worker {
1098*89c4ff92SAndroid Build Coastguard Worker     if (layer.GetType() != LayerType::Output)
1099*89c4ff92SAndroid Build Coastguard Worker     {
1100*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("EnqueueOutput: given layer not an OutputLayer");
1101*89c4ff92SAndroid Build Coastguard Worker     }
1102*89c4ff92SAndroid Build Coastguard Worker 
1103*89c4ff92SAndroid Build Coastguard Worker     if (tensorHandle == nullptr)
1104*89c4ff92SAndroid Build Coastguard Worker     {
1105*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("EnqueueOutput: tensorHandle must not be NULL");
1106*89c4ff92SAndroid Build Coastguard Worker     }
1107*89c4ff92SAndroid Build Coastguard Worker 
1108*89c4ff92SAndroid Build Coastguard Worker     OutputQueueDescriptor outputQueueDescriptor;
1109*89c4ff92SAndroid Build Coastguard Worker     WorkloadInfo info;
1110*89c4ff92SAndroid Build Coastguard Worker 
1111*89c4ff92SAndroid Build Coastguard Worker     outputQueueDescriptor.m_Outputs.push_back(tensorHandle);
1112*89c4ff92SAndroid Build Coastguard Worker     info.m_OutputTensorInfos.push_back(tensorInfo);
1113*89c4ff92SAndroid Build Coastguard Worker 
1114*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(layer.GetNumInputSlots() == 1, "Output Layer should have exactly one input.");
1115*89c4ff92SAndroid Build Coastguard Worker 
1116*89c4ff92SAndroid Build Coastguard Worker     // Gets the output handler from the previous node.
1117*89c4ff92SAndroid Build Coastguard Worker     const OutputHandler& outputHandler = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
1118*89c4ff92SAndroid Build Coastguard Worker 
1119*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo& inputTensorInfo = outputHandler.GetTensorInfo();
1120*89c4ff92SAndroid Build Coastguard Worker     ITensorHandle* inputTensorHandle = outputHandler.GetData();
1121*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated.");
1122*89c4ff92SAndroid Build Coastguard Worker 
1123*89c4ff92SAndroid Build Coastguard Worker     // Try import the output tensor.
1124*89c4ff92SAndroid Build Coastguard Worker     // Note: We can only import the output pointer if all of the following  hold true:
1125*89c4ff92SAndroid Build Coastguard Worker     // a) The imported pointer is aligned sufficiently
1126*89c4ff92SAndroid Build Coastguard Worker     // b) The tensor has zero padding
1127*89c4ff92SAndroid Build Coastguard Worker     // c) There is only one connection to the OutputSlot and it is to an OutputLayer.
1128*89c4ff92SAndroid Build Coastguard Worker     // d) The output pointer is allocated via malloc. (Other types will be supported in a later release)
1129*89c4ff92SAndroid Build Coastguard Worker     // e) m_IsExportEnabled must be set to true
1130*89c4ff92SAndroid Build Coastguard Worker     bool needMemCopy = true;
1131*89c4ff92SAndroid Build Coastguard Worker     if (m_NetworkProperties.m_ExportEnabled &&
1132*89c4ff92SAndroid Build Coastguard Worker         (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1))
1133*89c4ff92SAndroid Build Coastguard Worker     {
1134*89c4ff92SAndroid Build Coastguard Worker         if(layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer().GetType() != LayerType::Input)
1135*89c4ff92SAndroid Build Coastguard Worker         {
1136*89c4ff92SAndroid Build Coastguard Worker             MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
1137*89c4ff92SAndroid Build Coastguard Worker             if (CheckFlag(importFlags, m_NetworkProperties.m_OutputSource))
1138*89c4ff92SAndroid Build Coastguard Worker             {
1139*89c4ff92SAndroid Build Coastguard Worker                 needMemCopy = false;
1140*89c4ff92SAndroid Build Coastguard Worker                 void *mem = tensorHandle->Map(false);
1141*89c4ff92SAndroid Build Coastguard Worker                 bool importOk = inputTensorHandle->Import(mem, m_NetworkProperties.m_OutputSource);
1142*89c4ff92SAndroid Build Coastguard Worker                 tensorHandle->Unmap();
1143*89c4ff92SAndroid Build Coastguard Worker 
1144*89c4ff92SAndroid Build Coastguard Worker                 if (importOk)
1145*89c4ff92SAndroid Build Coastguard Worker                 {
1146*89c4ff92SAndroid Build Coastguard Worker                     // Insert synchronization workload
1147*89c4ff92SAndroid Build Coastguard Worker                     MemSyncQueueDescriptor syncDesc;
1148*89c4ff92SAndroid Build Coastguard Worker                     syncDesc.m_Inputs.push_back(inputTensorHandle);
1149*89c4ff92SAndroid Build Coastguard Worker                     info.m_InputTensorInfos.push_back(inputTensorInfo);
1150*89c4ff92SAndroid Build Coastguard Worker                     auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
1151*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
1152*89c4ff92SAndroid Build Coastguard Worker                     m_OutputQueue.push_back(move(syncWorkload));
1153*89c4ff92SAndroid Build Coastguard Worker                 }
1154*89c4ff92SAndroid Build Coastguard Worker                 else
1155*89c4ff92SAndroid Build Coastguard Worker                 {
1156*89c4ff92SAndroid Build Coastguard Worker                     throw MemoryExportException("EnqueueOutput: Memory Export failed");
1157*89c4ff92SAndroid Build Coastguard Worker                 }
1158*89c4ff92SAndroid Build Coastguard Worker             }
1159*89c4ff92SAndroid Build Coastguard Worker         }
1160*89c4ff92SAndroid Build Coastguard Worker     }
1161*89c4ff92SAndroid Build Coastguard Worker     if (needMemCopy)
1162*89c4ff92SAndroid Build Coastguard Worker     {
1163*89c4ff92SAndroid Build Coastguard Worker         // If we got here then we didn't export the memory, so add an output workload which performs a memcopy.
1164*89c4ff92SAndroid Build Coastguard Worker         outputQueueDescriptor.m_Inputs.push_back(inputTensorHandle);
1165*89c4ff92SAndroid Build Coastguard Worker         info.m_InputTensorInfos.push_back(inputTensorInfo);
1166*89c4ff92SAndroid Build Coastguard Worker 
1167*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<IWorkload> outputWorkload =
1168*89c4ff92SAndroid Build Coastguard Worker             std::make_unique<CopyMemGenericWorkload>(outputQueueDescriptor, info);
1169*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT_MSG(outputWorkload, "No output workload created");
1170*89c4ff92SAndroid Build Coastguard Worker 
1171*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1172*89c4ff92SAndroid Build Coastguard Worker             TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1173*89c4ff92SAndroid Build Coastguard Worker         if (timelineUtils)
1174*89c4ff92SAndroid Build Coastguard Worker         {
1175*89c4ff92SAndroid Build Coastguard Worker             // Add Output Workload to the post-optimisation network structure
1176*89c4ff92SAndroid Build Coastguard Worker             AddWorkloadStructure(timelineUtils, outputWorkload, layer);
1177*89c4ff92SAndroid Build Coastguard Worker             timelineUtils->Commit();
1178*89c4ff92SAndroid Build Coastguard Worker         }
1179*89c4ff92SAndroid Build Coastguard Worker 
1180*89c4ff92SAndroid Build Coastguard Worker         m_OutputQueue.push_back(move(outputWorkload));
1181*89c4ff92SAndroid Build Coastguard Worker     }
1182*89c4ff92SAndroid Build Coastguard Worker }
1183*89c4ff92SAndroid Build Coastguard Worker 
AllocateWorkingMemory(std::lock_guard<std::mutex> & lock)1184*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateWorkingMemory(
1185*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1186*89c4ff92SAndroid Build Coastguard Worker      std::lock_guard<std::mutex>& lock
1187*89c4ff92SAndroid Build Coastguard Worker #endif
1188*89c4ff92SAndroid Build Coastguard Worker     )
1189*89c4ff92SAndroid Build Coastguard Worker {
1190*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Working Memory Allocation");
1191*89c4ff92SAndroid Build Coastguard Worker 
1192*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1193*89c4ff92SAndroid Build Coastguard Worker     // this unused parameter makes sure we can only call this function with a valid lock
1194*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(lock);
1195*89c4ff92SAndroid Build Coastguard Worker #endif
1196*89c4ff92SAndroid Build Coastguard Worker     if (m_IsWorkingMemAllocated)
1197*89c4ff92SAndroid Build Coastguard Worker     {
1198*89c4ff92SAndroid Build Coastguard Worker         return;
1199*89c4ff92SAndroid Build Coastguard Worker     }
1200*89c4ff92SAndroid Build Coastguard Worker 
1201*89c4ff92SAndroid Build Coastguard Worker     if (m_ExternalMemoryManager)
1202*89c4ff92SAndroid Build Coastguard Worker     {
1203*89c4ff92SAndroid Build Coastguard Worker         m_ExternalMemoryManager->Allocate();
1204*89c4ff92SAndroid Build Coastguard Worker 
1205*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < m_TensorMemory.size(); ++i)
1206*89c4ff92SAndroid Build Coastguard Worker         {
1207*89c4ff92SAndroid Build Coastguard Worker             m_Tensorhandles[i]->Import(m_TensorMemory[i].first->m_Data, m_TensorMemory[i].second);
1208*89c4ff92SAndroid Build Coastguard Worker         }
1209*89c4ff92SAndroid Build Coastguard Worker     }
1210*89c4ff92SAndroid Build Coastguard Worker 
1211*89c4ff92SAndroid Build Coastguard Worker     for (auto&& memoryManager : m_BackendMemoryMangers)
1212*89c4ff92SAndroid Build Coastguard Worker     {
1213*89c4ff92SAndroid Build Coastguard Worker         if (memoryManager)
1214*89c4ff92SAndroid Build Coastguard Worker         {
1215*89c4ff92SAndroid Build Coastguard Worker             memoryManager->Acquire();
1216*89c4ff92SAndroid Build Coastguard Worker         }
1217*89c4ff92SAndroid Build Coastguard Worker     }
1218*89c4ff92SAndroid Build Coastguard Worker     m_TensorHandleFactoryRegistry.AquireMemory();
1219*89c4ff92SAndroid Build Coastguard Worker     m_IsWorkingMemAllocated = true;
1220*89c4ff92SAndroid Build Coastguard Worker }
1221*89c4ff92SAndroid Build Coastguard Worker 
FreeWorkingMemory()1222*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::FreeWorkingMemory()
1223*89c4ff92SAndroid Build Coastguard Worker {
1224*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1225*89c4ff92SAndroid Build Coastguard Worker     std::lock_guard<std::mutex> lockGuard(m_WorkingMemMutex);
1226*89c4ff92SAndroid Build Coastguard Worker #endif
1227*89c4ff92SAndroid Build Coastguard Worker 
1228*89c4ff92SAndroid Build Coastguard Worker     if (!m_IsWorkingMemAllocated)
1229*89c4ff92SAndroid Build Coastguard Worker     {
1230*89c4ff92SAndroid Build Coastguard Worker         return;
1231*89c4ff92SAndroid Build Coastguard Worker     }
1232*89c4ff92SAndroid Build Coastguard Worker 
1233*89c4ff92SAndroid Build Coastguard Worker     if (m_ExternalMemoryManager)
1234*89c4ff92SAndroid Build Coastguard Worker     {
1235*89c4ff92SAndroid Build Coastguard Worker         m_ExternalMemoryManager->Deallocate();
1236*89c4ff92SAndroid Build Coastguard Worker     }
1237*89c4ff92SAndroid Build Coastguard Worker 
1238*89c4ff92SAndroid Build Coastguard Worker     // Informs the memory managers to release memory in its respective memory group
1239*89c4ff92SAndroid Build Coastguard Worker     for (auto&& memoryManager : m_BackendMemoryMangers)
1240*89c4ff92SAndroid Build Coastguard Worker     {
1241*89c4ff92SAndroid Build Coastguard Worker         if (memoryManager)
1242*89c4ff92SAndroid Build Coastguard Worker         {
1243*89c4ff92SAndroid Build Coastguard Worker             memoryManager->Release();
1244*89c4ff92SAndroid Build Coastguard Worker         }
1245*89c4ff92SAndroid Build Coastguard Worker     }
1246*89c4ff92SAndroid Build Coastguard Worker     m_TensorHandleFactoryRegistry.ReleaseMemory();
1247*89c4ff92SAndroid Build Coastguard Worker     m_IsWorkingMemAllocated = false;
1248*89c4ff92SAndroid Build Coastguard Worker }
1249*89c4ff92SAndroid Build Coastguard Worker 
Execute(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,ProfilingGuid inferenceGuid)1250*89c4ff92SAndroid Build Coastguard Worker bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
1251*89c4ff92SAndroid Build Coastguard Worker                            ProfilingGuid inferenceGuid)
1252*89c4ff92SAndroid Build Coastguard Worker {
1253*89c4ff92SAndroid Build Coastguard Worker     bool success = true;
1254*89c4ff92SAndroid Build Coastguard Worker 
1255*89c4ff92SAndroid Build Coastguard Worker     auto Fail = [&](const std::exception& error)
1256*89c4ff92SAndroid Build Coastguard Worker     {
1257*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "An error occurred attempting to execute a workload: " << error.what();
1258*89c4ff92SAndroid Build Coastguard Worker         success = false;
1259*89c4ff92SAndroid Build Coastguard Worker     };
1260*89c4ff92SAndroid Build Coastguard Worker 
1261*89c4ff92SAndroid Build Coastguard Worker     try
1262*89c4ff92SAndroid Build Coastguard Worker     {
1263*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1264*89c4ff92SAndroid Build Coastguard Worker         std::lock_guard<std::mutex> lockGuard(m_WorkingMemMutex);
1265*89c4ff92SAndroid Build Coastguard Worker         AllocateWorkingMemory(lockGuard);
1266*89c4ff92SAndroid Build Coastguard Worker #else
1267*89c4ff92SAndroid Build Coastguard Worker         AllocateWorkingMemory();
1268*89c4ff92SAndroid Build Coastguard Worker #endif
1269*89c4ff92SAndroid Build Coastguard Worker 
1270*89c4ff92SAndroid Build Coastguard Worker         ProfilingDynamicGuid workloadInferenceID(0);
1271*89c4ff92SAndroid Build Coastguard Worker         auto ExecuteQueue = [&timelineUtils, &workloadInferenceID, &inferenceGuid](WorkloadQueue& queue)
1272*89c4ff92SAndroid Build Coastguard Worker         {
1273*89c4ff92SAndroid Build Coastguard Worker             for (auto& workload : queue)
1274*89c4ff92SAndroid Build Coastguard Worker             {
1275*89c4ff92SAndroid Build Coastguard Worker                 if(timelineUtils)
1276*89c4ff92SAndroid Build Coastguard Worker                 {
1277*89c4ff92SAndroid Build Coastguard Worker                     workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(),
1278*89c4ff92SAndroid Build Coastguard Worker                                                                                                     inferenceGuid);
1279*89c4ff92SAndroid Build Coastguard Worker                 }
1280*89c4ff92SAndroid Build Coastguard Worker                 workload->Execute();
1281*89c4ff92SAndroid Build Coastguard Worker                 if(timelineUtils)
1282*89c4ff92SAndroid Build Coastguard Worker                 {
1283*89c4ff92SAndroid Build Coastguard Worker                     timelineUtils->RecordEndOfLifeEvent(workloadInferenceID);
1284*89c4ff92SAndroid Build Coastguard Worker                 }
1285*89c4ff92SAndroid Build Coastguard Worker             }
1286*89c4ff92SAndroid Build Coastguard Worker         };
1287*89c4ff92SAndroid Build Coastguard Worker 
1288*89c4ff92SAndroid Build Coastguard Worker         ExecuteQueue(m_InputQueue);
1289*89c4ff92SAndroid Build Coastguard Worker         ExecuteQueue(m_WorkloadQueue);
1290*89c4ff92SAndroid Build Coastguard Worker         ExecuteQueue(m_OutputQueue);
1291*89c4ff92SAndroid Build Coastguard Worker     }
1292*89c4ff92SAndroid Build Coastguard Worker     catch (const RuntimeException& error)
1293*89c4ff92SAndroid Build Coastguard Worker     {
1294*89c4ff92SAndroid Build Coastguard Worker         Fail(error);
1295*89c4ff92SAndroid Build Coastguard Worker     }
1296*89c4ff92SAndroid Build Coastguard Worker     catch (const std::runtime_error& error)
1297*89c4ff92SAndroid Build Coastguard Worker     {
1298*89c4ff92SAndroid Build Coastguard Worker         Fail(error);
1299*89c4ff92SAndroid Build Coastguard Worker     }
1300*89c4ff92SAndroid Build Coastguard Worker 
1301*89c4ff92SAndroid Build Coastguard Worker     return success;
1302*89c4ff92SAndroid Build Coastguard Worker }
1303*89c4ff92SAndroid Build Coastguard Worker 
EnqueueInput(const ConstTensor & inputTensor,ITensorHandle * inputTensorHandle)1304*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueInput(const ConstTensor& inputTensor, ITensorHandle* inputTensorHandle)
1305*89c4ff92SAndroid Build Coastguard Worker {
1306*89c4ff92SAndroid Build Coastguard Worker     if (m_NetworkProperties.m_ImportEnabled)  // Try import the input tensor
1307*89c4ff92SAndroid Build Coastguard Worker     {
1308*89c4ff92SAndroid Build Coastguard Worker         MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
1309*89c4ff92SAndroid Build Coastguard Worker         if (CheckFlag(importFlags, m_NetworkProperties.m_InputSource) )
1310*89c4ff92SAndroid Build Coastguard Worker         {
1311*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<ITensorHandle> tensorHandle =
1312*89c4ff92SAndroid Build Coastguard Worker                     std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(),
1313*89c4ff92SAndroid Build Coastguard Worker                                                                    inputTensor.GetMemoryArea());
1314*89c4ff92SAndroid Build Coastguard Worker             void* mem = tensorHandle->Map(false);
1315*89c4ff92SAndroid Build Coastguard Worker 
1316*89c4ff92SAndroid Build Coastguard Worker             if (inputTensorHandle->Import(mem, m_NetworkProperties.m_InputSource))
1317*89c4ff92SAndroid Build Coastguard Worker             {
1318*89c4ff92SAndroid Build Coastguard Worker                 tensorHandle->Unmap();
1319*89c4ff92SAndroid Build Coastguard Worker                 return;
1320*89c4ff92SAndroid Build Coastguard Worker             }
1321*89c4ff92SAndroid Build Coastguard Worker             tensorHandle->Unmap();
1322*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("EnqueueInput: Memory Import failed");
1323*89c4ff92SAndroid Build Coastguard Worker         }
1324*89c4ff92SAndroid Build Coastguard Worker         else
1325*89c4ff92SAndroid Build Coastguard Worker         {
1326*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("EnqueueInput: Memory Import failed, backend does not support Import");
1327*89c4ff92SAndroid Build Coastguard Worker         }
1328*89c4ff92SAndroid Build Coastguard Worker     }
1329*89c4ff92SAndroid Build Coastguard Worker     else
1330*89c4ff92SAndroid Build Coastguard Worker     {
1331*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyInput");
1332*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<ITensorHandle> tensorHandle =
1333*89c4ff92SAndroid Build Coastguard Worker                 std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(), inputTensor.GetMemoryArea());
1334*89c4ff92SAndroid Build Coastguard Worker 
1335*89c4ff92SAndroid Build Coastguard Worker         auto copyFunc = [](void* dst, const void* src, size_t size)
1336*89c4ff92SAndroid Build Coastguard Worker         {
1337*89c4ff92SAndroid Build Coastguard Worker             memcpy(dst, src, size);
1338*89c4ff92SAndroid Build Coastguard Worker         };
1339*89c4ff92SAndroid Build Coastguard Worker 
1340*89c4ff92SAndroid Build Coastguard Worker         CopyTensorContentsGeneric(tensorHandle.get(), inputTensorHandle, copyFunc);
1341*89c4ff92SAndroid Build Coastguard Worker     }
1342*89c4ff92SAndroid Build Coastguard Worker }
1343*89c4ff92SAndroid Build Coastguard Worker 
1344*89c4ff92SAndroid Build Coastguard Worker // Note: We can only import the output pointer if all of the following  hold true:
1345*89c4ff92SAndroid Build Coastguard Worker // a) The imported pointer is aligned sufficiently
1346*89c4ff92SAndroid Build Coastguard Worker // b) The tensor has zero padding
1347*89c4ff92SAndroid Build Coastguard Worker // c) There is only one connection to the OutputSlot and it is to an OutputLayer.
1348*89c4ff92SAndroid Build Coastguard Worker // d) The output pointer is allocated via malloc. (Other types will be supported in a later release)
1349*89c4ff92SAndroid Build Coastguard Worker // e) m_IsExportEnabled must be set to true
ImportOutputTensor(const Tensor & outputTensor,ITensorHandle * outputTensorHandle)1350*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ImportOutputTensor(const Tensor& outputTensor, ITensorHandle* outputTensorHandle)
1351*89c4ff92SAndroid Build Coastguard Worker {
1352*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(outputTensorHandle != nullptr, "Data should have been allocated.");
1353*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags importFlags = outputTensorHandle->GetImportFlags();
1354*89c4ff92SAndroid Build Coastguard Worker     if (CheckFlag(importFlags, m_NetworkProperties.m_OutputSource))
1355*89c4ff92SAndroid Build Coastguard Worker     {
1356*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<ITensorHandle> tensorHandle =
1357*89c4ff92SAndroid Build Coastguard Worker                 std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(),
1358*89c4ff92SAndroid Build Coastguard Worker                                                           outputTensor.GetMemoryArea());
1359*89c4ff92SAndroid Build Coastguard Worker 
1360*89c4ff92SAndroid Build Coastguard Worker         void* mem = tensorHandle->Map(false);
1361*89c4ff92SAndroid Build Coastguard Worker         bool importOk = outputTensorHandle->Import(mem, m_NetworkProperties.m_OutputSource);
1362*89c4ff92SAndroid Build Coastguard Worker         tensorHandle->Unmap();
1363*89c4ff92SAndroid Build Coastguard Worker 
1364*89c4ff92SAndroid Build Coastguard Worker         if (!importOk)
1365*89c4ff92SAndroid Build Coastguard Worker         {
1366*89c4ff92SAndroid Build Coastguard Worker             throw MemoryExportException("ImportOutputTensor: Memory Export failed");
1367*89c4ff92SAndroid Build Coastguard Worker         }
1368*89c4ff92SAndroid Build Coastguard Worker     }
1369*89c4ff92SAndroid Build Coastguard Worker     else
1370*89c4ff92SAndroid Build Coastguard Worker     {
1371*89c4ff92SAndroid Build Coastguard Worker         throw MemoryExportException("ImportOutputTensor: Memory Export failed, attempting to export Input Layer");
1372*89c4ff92SAndroid Build Coastguard Worker     }
1373*89c4ff92SAndroid Build Coastguard Worker 
1374*89c4ff92SAndroid Build Coastguard Worker }
1375*89c4ff92SAndroid Build Coastguard Worker 
CopyToOutputTensor(const Tensor & outputTensor,ITensorHandle * outputTensorHandle)1376*89c4ff92SAndroid Build Coastguard Worker void CopyToOutputTensor(const Tensor& outputTensor, ITensorHandle* outputTensorHandle)
1377*89c4ff92SAndroid Build Coastguard Worker {
1378*89c4ff92SAndroid Build Coastguard Worker     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyOutput");
1379*89c4ff92SAndroid Build Coastguard Worker     auto copyFunc = [](void* dst, const void* src, size_t size)
1380*89c4ff92SAndroid Build Coastguard Worker     {
1381*89c4ff92SAndroid Build Coastguard Worker         memcpy(dst, src, size);
1382*89c4ff92SAndroid Build Coastguard Worker     };
1383*89c4ff92SAndroid Build Coastguard Worker 
1384*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> tensorHandle =
1385*89c4ff92SAndroid Build Coastguard Worker             std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(),
1386*89c4ff92SAndroid Build Coastguard Worker                                                       outputTensor.GetMemoryArea());
1387*89c4ff92SAndroid Build Coastguard Worker 
1388*89c4ff92SAndroid Build Coastguard Worker     CopyTensorContentsGeneric(outputTensorHandle, tensorHandle.get(), copyFunc);
1389*89c4ff92SAndroid Build Coastguard Worker }
1390*89c4ff92SAndroid Build Coastguard Worker 
1391*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensor(const LayerBindingId layerId,const InputTensors & inputTensors)1392*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor GetInputTensor(const LayerBindingId layerId, const InputTensors& inputTensors)
1393*89c4ff92SAndroid Build Coastguard Worker {
1394*89c4ff92SAndroid Build Coastguard Worker     for (auto inputTensorPair : inputTensors)
1395*89c4ff92SAndroid Build Coastguard Worker     {
1396*89c4ff92SAndroid Build Coastguard Worker         LayerBindingId id = inputTensorPair.first;
1397*89c4ff92SAndroid Build Coastguard Worker         if (id == layerId)
1398*89c4ff92SAndroid Build Coastguard Worker         {
1399*89c4ff92SAndroid Build Coastguard Worker             return inputTensorPair.second;
1400*89c4ff92SAndroid Build Coastguard Worker         }
1401*89c4ff92SAndroid Build Coastguard Worker     }
1402*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException("Input does not exist.");
1403*89c4ff92SAndroid Build Coastguard Worker }
1404*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensor(const LayerBindingId layerId,const OutputTensors & outputTensors)1405*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor GetOutputTensor(const LayerBindingId layerId, const OutputTensors& outputTensors)
1406*89c4ff92SAndroid Build Coastguard Worker {
1407*89c4ff92SAndroid Build Coastguard Worker     for (auto outputTensorPair : outputTensors)
1408*89c4ff92SAndroid Build Coastguard Worker     {
1409*89c4ff92SAndroid Build Coastguard Worker         LayerBindingId id = outputTensorPair.first;
1410*89c4ff92SAndroid Build Coastguard Worker         if (id == layerId)
1411*89c4ff92SAndroid Build Coastguard Worker         {
1412*89c4ff92SAndroid Build Coastguard Worker             return outputTensorPair.second;
1413*89c4ff92SAndroid Build Coastguard Worker         }
1414*89c4ff92SAndroid Build Coastguard Worker     }
1415*89c4ff92SAndroid Build Coastguard Worker     throw InvalidArgumentException("Output does not exist.");
1416*89c4ff92SAndroid Build Coastguard Worker }
1417*89c4ff92SAndroid Build Coastguard Worker 
ImportInputs(const InputTensors & inputTensors,MemorySource forceImportMemorySource)1418*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inputTensors,
1419*89c4ff92SAndroid Build Coastguard Worker                                                          MemorySource forceImportMemorySource)
1420*89c4ff92SAndroid Build Coastguard Worker {
1421*89c4ff92SAndroid Build Coastguard Worker     if (!m_NetworkProperties.m_AsyncEnabled)
1422*89c4ff92SAndroid Build Coastguard Worker     {
1423*89c4ff92SAndroid Build Coastguard Worker         // Cannot import if import is not enabled and forceImportMemorySource is undefined
1424*89c4ff92SAndroid Build Coastguard Worker         if (forceImportMemorySource == MemorySource::Undefined)
1425*89c4ff92SAndroid Build Coastguard Worker         {
1426*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ImportInputs: Memory Import failed, NetworkProperties.m_ImportEnabled");
1427*89c4ff92SAndroid Build Coastguard Worker         }
1428*89c4ff92SAndroid Build Coastguard Worker         // The number of pre imported tensors should not exceed the number of inputs.
1429*89c4ff92SAndroid Build Coastguard Worker         if (inputTensors.size() > m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumInputs())
1430*89c4ff92SAndroid Build Coastguard Worker         {
1431*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ImportInputs: The number of tensors provided exceeds the number of inputs.");
1432*89c4ff92SAndroid Build Coastguard Worker         }
1433*89c4ff92SAndroid Build Coastguard Worker 
1434*89c4ff92SAndroid Build Coastguard Worker         std::vector<ImportedInputId> importedInputs;
1435*89c4ff92SAndroid Build Coastguard Worker         Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1436*89c4ff92SAndroid Build Coastguard Worker         unsigned int inputIndex = 0;
1437*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* inputLayer : graph.GetInputLayers())
1438*89c4ff92SAndroid Build Coastguard Worker         {
1439*89c4ff92SAndroid Build Coastguard Worker             auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
1440*89c4ff92SAndroid Build Coastguard Worker 
1441*89c4ff92SAndroid Build Coastguard Worker             if (!outputTensorHandle)
1442*89c4ff92SAndroid Build Coastguard Worker             {
1443*89c4ff92SAndroid Build Coastguard Worker                 inputIndex++;
1444*89c4ff92SAndroid Build Coastguard Worker                 continue;
1445*89c4ff92SAndroid Build Coastguard Worker             }
1446*89c4ff92SAndroid Build Coastguard Worker 
1447*89c4ff92SAndroid Build Coastguard Worker             auto layerBindingId = inputLayer->GetBindingId();
1448*89c4ff92SAndroid Build Coastguard Worker             auto it = std::find_if(inputTensors.begin(), inputTensors.end(), [=](const auto& inputTensor)
1449*89c4ff92SAndroid Build Coastguard Worker             {
1450*89c4ff92SAndroid Build Coastguard Worker                 return inputTensor.first == layerBindingId;
1451*89c4ff92SAndroid Build Coastguard Worker             });
1452*89c4ff92SAndroid Build Coastguard Worker 
1453*89c4ff92SAndroid Build Coastguard Worker             if (it == inputTensors.end())
1454*89c4ff92SAndroid Build Coastguard Worker             {
1455*89c4ff92SAndroid Build Coastguard Worker                 inputIndex++;
1456*89c4ff92SAndroid Build Coastguard Worker                 continue;
1457*89c4ff92SAndroid Build Coastguard Worker             }
1458*89c4ff92SAndroid Build Coastguard Worker 
1459*89c4ff92SAndroid Build Coastguard Worker             const auto& inputTensor = *it;
1460*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<ITensorHandle> passThroughTensorHandle =
1461*89c4ff92SAndroid Build Coastguard Worker                     std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
1462*89c4ff92SAndroid Build Coastguard Worker                                                                    inputTensor.second.GetMemoryArea());
1463*89c4ff92SAndroid Build Coastguard Worker 
1464*89c4ff92SAndroid Build Coastguard Worker             try
1465*89c4ff92SAndroid Build Coastguard Worker             {
1466*89c4ff92SAndroid Build Coastguard Worker                 if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource)
1467*89c4ff92SAndroid Build Coastguard Worker                     && (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource)))
1468*89c4ff92SAndroid Build Coastguard Worker                 {
1469*89c4ff92SAndroid Build Coastguard Worker                     importedInputs.push_back(inputIndex);
1470*89c4ff92SAndroid Build Coastguard Worker                 }
1471*89c4ff92SAndroid Build Coastguard Worker                 passThroughTensorHandle->Unmap();
1472*89c4ff92SAndroid Build Coastguard Worker             }
1473*89c4ff92SAndroid Build Coastguard Worker             catch(const MemoryImportException& exception)
1474*89c4ff92SAndroid Build Coastguard Worker             {
1475*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "An error occurred attempting to import input_"
1476*89c4ff92SAndroid Build Coastguard Worker                                            << inputIndex << " : " << exception.what();
1477*89c4ff92SAndroid Build Coastguard Worker                 passThroughTensorHandle->Unmap();
1478*89c4ff92SAndroid Build Coastguard Worker             }
1479*89c4ff92SAndroid Build Coastguard Worker             inputIndex++;
1480*89c4ff92SAndroid Build Coastguard Worker         }
1481*89c4ff92SAndroid Build Coastguard Worker 
1482*89c4ff92SAndroid Build Coastguard Worker         return importedInputs;
1483*89c4ff92SAndroid Build Coastguard Worker     }
1484*89c4ff92SAndroid Build Coastguard Worker     else
1485*89c4ff92SAndroid Build Coastguard Worker     {
1486*89c4ff92SAndroid Build Coastguard Worker         // Import when the import of network properties is enabled
1487*89c4ff92SAndroid Build Coastguard Worker         std::vector<ImportedInputId> importedInputs;
1488*89c4ff92SAndroid Build Coastguard Worker         Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1489*89c4ff92SAndroid Build Coastguard Worker 
1490*89c4ff92SAndroid Build Coastguard Worker         for (auto inputTensor : inputTensors)
1491*89c4ff92SAndroid Build Coastguard Worker         {
1492*89c4ff92SAndroid Build Coastguard Worker             auto layerBindingId = inputTensor.first;
1493*89c4ff92SAndroid Build Coastguard Worker             auto it = std::find_if(graph.GetInputLayers().begin(), graph.GetInputLayers().end(), [=](auto* layer)
1494*89c4ff92SAndroid Build Coastguard Worker             {
1495*89c4ff92SAndroid Build Coastguard Worker                 return layer->GetBindingId() == layerBindingId;
1496*89c4ff92SAndroid Build Coastguard Worker             });
1497*89c4ff92SAndroid Build Coastguard Worker 
1498*89c4ff92SAndroid Build Coastguard Worker             if (it == graph.GetInputLayers().end())
1499*89c4ff92SAndroid Build Coastguard Worker             {
1500*89c4ff92SAndroid Build Coastguard Worker                 throw MemoryImportException(fmt::format(
1501*89c4ff92SAndroid Build Coastguard Worker                     "ImportInputs: Memory Import failed, unknown LayerBindingId: {}", layerBindingId));
1502*89c4ff92SAndroid Build Coastguard Worker             }
1503*89c4ff92SAndroid Build Coastguard Worker 
1504*89c4ff92SAndroid Build Coastguard Worker             const Layer* layer = *it;
1505*89c4ff92SAndroid Build Coastguard Worker             if (layer->GetType() != LayerType::Input)
1506*89c4ff92SAndroid Build Coastguard Worker             {
1507*89c4ff92SAndroid Build Coastguard Worker                 throw InvalidArgumentException("ImportInputs: given layer not an InputLayer");
1508*89c4ff92SAndroid Build Coastguard Worker             }
1509*89c4ff92SAndroid Build Coastguard Worker 
1510*89c4ff92SAndroid Build Coastguard Worker             auto& backend = m_Backends.at(layer->GetBackendId());
1511*89c4ff92SAndroid Build Coastguard Worker             if (!HasCapability(BackendOptions::BackendOption{"PreImportIOTensors", true}, backend->GetCapabilities()))
1512*89c4ff92SAndroid Build Coastguard Worker             {
1513*89c4ff92SAndroid Build Coastguard Worker                 std::string er = backend->GetId();
1514*89c4ff92SAndroid Build Coastguard Worker                 er += " does not have PreImportIOTensors capability";
1515*89c4ff92SAndroid Build Coastguard Worker                 throw BackendCapabilityException(er);
1516*89c4ff92SAndroid Build Coastguard Worker             }
1517*89c4ff92SAndroid Build Coastguard Worker 
1518*89c4ff92SAndroid Build Coastguard Worker             const OutputSlot& outputSlot = layer->GetOutputSlots()[0];
1519*89c4ff92SAndroid Build Coastguard Worker 
1520*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
1521*89c4ff92SAndroid Build Coastguard Worker             const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1522*89c4ff92SAndroid Build Coastguard Worker 
1523*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1524*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(handleFactory);
1525*89c4ff92SAndroid Build Coastguard Worker 
1526*89c4ff92SAndroid Build Coastguard Worker             ImportedTensorHandlePin importedTensorHandlePin{layerBindingId,
1527*89c4ff92SAndroid Build Coastguard Worker                                                             handleFactory->CreateTensorHandle(tensorInfo, false)};
1528*89c4ff92SAndroid Build Coastguard Worker 
1529*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
1530*89c4ff92SAndroid Build Coastguard Worker 
1531*89c4ff92SAndroid Build Coastguard Worker             if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
1532*89c4ff92SAndroid Build Coastguard Worker             {
1533*89c4ff92SAndroid Build Coastguard Worker                 throw MemoryImportException(
1534*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("ImportInputs: Memory Import failed, backend: "
1535*89c4ff92SAndroid Build Coastguard Worker                                 "{} does not support importing from source {}"
1536*89c4ff92SAndroid Build Coastguard Worker                                 , factoryId, m_NetworkProperties.m_InputSource));
1537*89c4ff92SAndroid Build Coastguard Worker             }
1538*89c4ff92SAndroid Build Coastguard Worker 
1539*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<ITensorHandle> passThroughTensorHandle =
1540*89c4ff92SAndroid Build Coastguard Worker                     std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
1541*89c4ff92SAndroid Build Coastguard Worker                                                                    inputTensor.second.GetMemoryArea());
1542*89c4ff92SAndroid Build Coastguard Worker 
1543*89c4ff92SAndroid Build Coastguard Worker             if (tensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))
1544*89c4ff92SAndroid Build Coastguard Worker             {
1545*89c4ff92SAndroid Build Coastguard Worker                 importedInputs.push_back(m_CurImportedInputId++);
1546*89c4ff92SAndroid Build Coastguard Worker                 passThroughTensorHandle->Unmap();
1547*89c4ff92SAndroid Build Coastguard Worker             }
1548*89c4ff92SAndroid Build Coastguard Worker             else
1549*89c4ff92SAndroid Build Coastguard Worker             {
1550*89c4ff92SAndroid Build Coastguard Worker                 passThroughTensorHandle->Unmap();
1551*89c4ff92SAndroid Build Coastguard Worker                 throw MemoryImportException("ImportInputs: Memory Import failed");
1552*89c4ff92SAndroid Build Coastguard Worker             }
1553*89c4ff92SAndroid Build Coastguard Worker 
1554*89c4ff92SAndroid Build Coastguard Worker             m_PreImportedInputHandles.push_back(std::move(importedTensorHandlePin));
1555*89c4ff92SAndroid Build Coastguard Worker         }
1556*89c4ff92SAndroid Build Coastguard Worker         return importedInputs;
1557*89c4ff92SAndroid Build Coastguard Worker     }
1558*89c4ff92SAndroid Build Coastguard Worker }
1559*89c4ff92SAndroid Build Coastguard Worker 
ImportOutputs(const OutputTensors & outputTensors,MemorySource forceImportMemorySource)1560*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& outputTensors,
1561*89c4ff92SAndroid Build Coastguard Worker                                                            MemorySource forceImportMemorySource)
1562*89c4ff92SAndroid Build Coastguard Worker {
1563*89c4ff92SAndroid Build Coastguard Worker     if (!m_NetworkProperties.m_AsyncEnabled)
1564*89c4ff92SAndroid Build Coastguard Worker     {
1565*89c4ff92SAndroid Build Coastguard Worker         // Cannot import if import is not enabled and forceImportMemorySource is undefined
1566*89c4ff92SAndroid Build Coastguard Worker         if (forceImportMemorySource == MemorySource::Undefined)
1567*89c4ff92SAndroid Build Coastguard Worker         {
1568*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ImportOutputs: Memory Import failed, NetworkProperties.m_ImportEnabled");
1569*89c4ff92SAndroid Build Coastguard Worker         }
1570*89c4ff92SAndroid Build Coastguard Worker         // If forceImportMemorySource is defined, try import if memory is aligned
1571*89c4ff92SAndroid Build Coastguard Worker         if (outputTensors.size() != m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumOutputs())
1572*89c4ff92SAndroid Build Coastguard Worker         {
1573*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ImportOutputs: Force Import failed, incorrect number of tensors");
1574*89c4ff92SAndroid Build Coastguard Worker         }
1575*89c4ff92SAndroid Build Coastguard Worker         std::vector<ImportedOutputId> importedOutputs;
1576*89c4ff92SAndroid Build Coastguard Worker         Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1577*89c4ff92SAndroid Build Coastguard Worker 
1578*89c4ff92SAndroid Build Coastguard Worker         unsigned int outputIndex = 0;
1579*89c4ff92SAndroid Build Coastguard Worker         for (const BindableLayer* const outputLayer : graph.GetOutputLayers())
1580*89c4ff92SAndroid Build Coastguard Worker         {
1581*89c4ff92SAndroid Build Coastguard Worker             auto inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
1582*89c4ff92SAndroid Build Coastguard Worker             if (!inputTensorHandle)
1583*89c4ff92SAndroid Build Coastguard Worker             {
1584*89c4ff92SAndroid Build Coastguard Worker                 outputIndex++;
1585*89c4ff92SAndroid Build Coastguard Worker                 continue;
1586*89c4ff92SAndroid Build Coastguard Worker             }
1587*89c4ff92SAndroid Build Coastguard Worker 
1588*89c4ff92SAndroid Build Coastguard Worker             auto layerBindingId = outputLayer->GetBindingId();
1589*89c4ff92SAndroid Build Coastguard Worker             auto it = std::find_if(outputTensors.begin(), outputTensors.end(), [=] (const auto& outputTensor)
1590*89c4ff92SAndroid Build Coastguard Worker             {
1591*89c4ff92SAndroid Build Coastguard Worker                 return outputTensor.first == layerBindingId;
1592*89c4ff92SAndroid Build Coastguard Worker             });
1593*89c4ff92SAndroid Build Coastguard Worker 
1594*89c4ff92SAndroid Build Coastguard Worker             if (it == outputTensors.end())
1595*89c4ff92SAndroid Build Coastguard Worker             {
1596*89c4ff92SAndroid Build Coastguard Worker                 outputIndex++;
1597*89c4ff92SAndroid Build Coastguard Worker                 continue;
1598*89c4ff92SAndroid Build Coastguard Worker             }
1599*89c4ff92SAndroid Build Coastguard Worker 
1600*89c4ff92SAndroid Build Coastguard Worker             const auto outputTensor = *it;
1601*89c4ff92SAndroid Build Coastguard Worker             try
1602*89c4ff92SAndroid Build Coastguard Worker             {
1603*89c4ff92SAndroid Build Coastguard Worker                 // Check if the output memory can be imported
1604*89c4ff92SAndroid Build Coastguard Worker                 if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource)
1605*89c4ff92SAndroid Build Coastguard Worker                     && inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
1606*89c4ff92SAndroid Build Coastguard Worker                 {
1607*89c4ff92SAndroid Build Coastguard Worker                     importedOutputs.push_back(outputIndex);
1608*89c4ff92SAndroid Build Coastguard Worker                 }
1609*89c4ff92SAndroid Build Coastguard Worker             }
1610*89c4ff92SAndroid Build Coastguard Worker             catch(const MemoryImportException& exception)
1611*89c4ff92SAndroid Build Coastguard Worker             {
1612*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_LOG(error) << "An error occurred attempting to import output_"
1613*89c4ff92SAndroid Build Coastguard Worker                                  << outputIndex << " : " << exception.what();
1614*89c4ff92SAndroid Build Coastguard Worker             }
1615*89c4ff92SAndroid Build Coastguard Worker             outputIndex++;
1616*89c4ff92SAndroid Build Coastguard Worker         }
1617*89c4ff92SAndroid Build Coastguard Worker         return importedOutputs;
1618*89c4ff92SAndroid Build Coastguard Worker     }
1619*89c4ff92SAndroid Build Coastguard Worker 
1620*89c4ff92SAndroid Build Coastguard Worker     std::vector<ImportedOutputId> importedOutputs;
1621*89c4ff92SAndroid Build Coastguard Worker     Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1622*89c4ff92SAndroid Build Coastguard Worker 
1623*89c4ff92SAndroid Build Coastguard Worker     for (const auto& outputTensor : outputTensors)
1624*89c4ff92SAndroid Build Coastguard Worker     {
1625*89c4ff92SAndroid Build Coastguard Worker         auto layerBindingId = outputTensor.first;
1626*89c4ff92SAndroid Build Coastguard Worker         auto it = std::find_if(graph.GetOutputLayers().begin(), graph.GetOutputLayers().end(), [=](auto* layer)
1627*89c4ff92SAndroid Build Coastguard Worker         {
1628*89c4ff92SAndroid Build Coastguard Worker             return layer->GetBindingId() == layerBindingId;
1629*89c4ff92SAndroid Build Coastguard Worker         });
1630*89c4ff92SAndroid Build Coastguard Worker 
1631*89c4ff92SAndroid Build Coastguard Worker         if (it == graph.GetOutputLayers().end())
1632*89c4ff92SAndroid Build Coastguard Worker         {
1633*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException(fmt::format("ImportOutputs: Memory Import failed, unknown LayerBindingId: {}",
1634*89c4ff92SAndroid Build Coastguard Worker                                                      layerBindingId));
1635*89c4ff92SAndroid Build Coastguard Worker         }
1636*89c4ff92SAndroid Build Coastguard Worker 
1637*89c4ff92SAndroid Build Coastguard Worker         const Layer* layer = *it;
1638*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType() != LayerType::Output)
1639*89c4ff92SAndroid Build Coastguard Worker         {
1640*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("ImportOutputs: given layer not an OutputLayer");
1641*89c4ff92SAndroid Build Coastguard Worker         }
1642*89c4ff92SAndroid Build Coastguard Worker 
1643*89c4ff92SAndroid Build Coastguard Worker         auto& backend = m_Backends.at(layer->GetBackendId());
1644*89c4ff92SAndroid Build Coastguard Worker         if (!HasCapability(BackendOptions::BackendOption{"PreImportIOTensors", true}, backend->GetCapabilities()))
1645*89c4ff92SAndroid Build Coastguard Worker         {
1646*89c4ff92SAndroid Build Coastguard Worker             std::string er = backend->GetId();
1647*89c4ff92SAndroid Build Coastguard Worker             er += " does not have PreImportIOTensors capability";
1648*89c4ff92SAndroid Build Coastguard Worker             throw BackendCapabilityException(er);
1649*89c4ff92SAndroid Build Coastguard Worker         }
1650*89c4ff92SAndroid Build Coastguard Worker 
1651*89c4ff92SAndroid Build Coastguard Worker         const InputSlot& inputSlot = layer->GetInputSlots()[0];
1652*89c4ff92SAndroid Build Coastguard Worker         ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId();
1653*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo();
1654*89c4ff92SAndroid Build Coastguard Worker 
1655*89c4ff92SAndroid Build Coastguard Worker         ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1656*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(handleFactory);
1657*89c4ff92SAndroid Build Coastguard Worker 
1658*89c4ff92SAndroid Build Coastguard Worker         ImportedTensorHandlePin importedTensorHandlePin{layerBindingId,
1659*89c4ff92SAndroid Build Coastguard Worker                                                         handleFactory->CreateTensorHandle(tensorInfo, false)};
1660*89c4ff92SAndroid Build Coastguard Worker 
1661*89c4ff92SAndroid Build Coastguard Worker         ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
1662*89c4ff92SAndroid Build Coastguard Worker 
1663*89c4ff92SAndroid Build Coastguard Worker         if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
1664*89c4ff92SAndroid Build Coastguard Worker         {
1665*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException(fmt::format("ImportInputs: Memory Import failed, backend: "
1666*89c4ff92SAndroid Build Coastguard Worker                                                     "{} does not support importing from source {}"
1667*89c4ff92SAndroid Build Coastguard Worker                                                     , factoryId, forceImportMemorySource));
1668*89c4ff92SAndroid Build Coastguard Worker         }
1669*89c4ff92SAndroid Build Coastguard Worker 
1670*89c4ff92SAndroid Build Coastguard Worker         if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
1671*89c4ff92SAndroid Build Coastguard Worker         {
1672*89c4ff92SAndroid Build Coastguard Worker             importedOutputs.push_back(m_CurImportedOutputId++);
1673*89c4ff92SAndroid Build Coastguard Worker         }
1674*89c4ff92SAndroid Build Coastguard Worker         else
1675*89c4ff92SAndroid Build Coastguard Worker         {
1676*89c4ff92SAndroid Build Coastguard Worker             throw MemoryImportException("ImportInputs: Memory Import failed");
1677*89c4ff92SAndroid Build Coastguard Worker         }
1678*89c4ff92SAndroid Build Coastguard Worker 
1679*89c4ff92SAndroid Build Coastguard Worker         m_PreImportedOutputHandles.push_back(std::move(importedTensorHandlePin));
1680*89c4ff92SAndroid Build Coastguard Worker     }
1681*89c4ff92SAndroid Build Coastguard Worker 
1682*89c4ff92SAndroid Build Coastguard Worker     return importedOutputs;
1683*89c4ff92SAndroid Build Coastguard Worker }
1684*89c4ff92SAndroid Build Coastguard Worker 
ClearImportedInputs(const std::vector<ImportedInputId> inputIds)1685*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ClearImportedInputs(const std::vector<ImportedInputId> inputIds)
1686*89c4ff92SAndroid Build Coastguard Worker {
1687*89c4ff92SAndroid Build Coastguard Worker     for (auto id : inputIds)
1688*89c4ff92SAndroid Build Coastguard Worker     {
1689*89c4ff92SAndroid Build Coastguard Worker         if (id > m_PreImportedInputHandles.size())
1690*89c4ff92SAndroid Build Coastguard Worker         {
1691*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(fmt::format("ClearImportedInputs::Unknown ImportedInputId: {}", id));
1692*89c4ff92SAndroid Build Coastguard Worker         }
1693*89c4ff92SAndroid Build Coastguard Worker 
1694*89c4ff92SAndroid Build Coastguard Worker         auto& importedTensorHandle = m_PreImportedInputHandles[id].m_TensorHandle;
1695*89c4ff92SAndroid Build Coastguard Worker         if (!importedTensorHandle)
1696*89c4ff92SAndroid Build Coastguard Worker         {
1697*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(
1698*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("ClearImportedInputs::ImportedInput with id: {} has already been deleted", id));
1699*89c4ff92SAndroid Build Coastguard Worker         }
1700*89c4ff92SAndroid Build Coastguard Worker         // Call Unimport then destroy the tensorHandle
1701*89c4ff92SAndroid Build Coastguard Worker         importedTensorHandle->Unimport();
1702*89c4ff92SAndroid Build Coastguard Worker         importedTensorHandle = {};
1703*89c4ff92SAndroid Build Coastguard Worker     }
1704*89c4ff92SAndroid Build Coastguard Worker }
1705*89c4ff92SAndroid Build Coastguard Worker 
ClearImportedOutputs(const std::vector<ImportedOutputId> outputIds)1706*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ClearImportedOutputs(const std::vector<ImportedOutputId> outputIds)
1707*89c4ff92SAndroid Build Coastguard Worker {
1708*89c4ff92SAndroid Build Coastguard Worker     for (auto id : outputIds)
1709*89c4ff92SAndroid Build Coastguard Worker     {
1710*89c4ff92SAndroid Build Coastguard Worker         if (id > m_PreImportedOutputHandles.size())
1711*89c4ff92SAndroid Build Coastguard Worker         {
1712*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(fmt::format("ClearImportedOutputs::Unknown ImportedOutputId: {}", id));
1713*89c4ff92SAndroid Build Coastguard Worker         }
1714*89c4ff92SAndroid Build Coastguard Worker 
1715*89c4ff92SAndroid Build Coastguard Worker         auto& importedTensorHandle = m_PreImportedOutputHandles[id].m_TensorHandle;
1716*89c4ff92SAndroid Build Coastguard Worker         if (!importedTensorHandle)
1717*89c4ff92SAndroid Build Coastguard Worker         {
1718*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(
1719*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("ClearImportedOutputs::ImportedOutput with id: {} has already been deleted", id));
1720*89c4ff92SAndroid Build Coastguard Worker         }
1721*89c4ff92SAndroid Build Coastguard Worker         // Call Unimport then destroy the tensorHandle
1722*89c4ff92SAndroid Build Coastguard Worker         importedTensorHandle->Unimport();
1723*89c4ff92SAndroid Build Coastguard Worker         importedTensorHandle = {};
1724*89c4ff92SAndroid Build Coastguard Worker     }
1725*89c4ff92SAndroid Build Coastguard Worker }
1726*89c4ff92SAndroid Build Coastguard Worker 
Execute(const InputTensors & inputTensors,const OutputTensors & outputTensors,IWorkingMemHandle & iWorkingMemHandle,std::vector<ImportedInputId> preImportedInputs,std::vector<ImportedOutputId> preImportedOutputs)1727*89c4ff92SAndroid Build Coastguard Worker Status LoadedNetwork::Execute(const InputTensors& inputTensors,
1728*89c4ff92SAndroid Build Coastguard Worker                               const OutputTensors& outputTensors,
1729*89c4ff92SAndroid Build Coastguard Worker                               IWorkingMemHandle& iWorkingMemHandle,
1730*89c4ff92SAndroid Build Coastguard Worker                               std::vector<ImportedInputId> preImportedInputs,
1731*89c4ff92SAndroid Build Coastguard Worker                               std::vector<ImportedOutputId> preImportedOutputs)
1732*89c4ff92SAndroid Build Coastguard Worker {
1733*89c4ff92SAndroid Build Coastguard Worker     const Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
1734*89c4ff92SAndroid Build Coastguard Worker 
1735*89c4ff92SAndroid Build Coastguard Worker     if (inputTensors.size() + preImportedInputs.size() != graph.GetNumInputs())
1736*89c4ff92SAndroid Build Coastguard Worker     {
1737*89c4ff92SAndroid Build Coastguard Worker         if (preImportedInputs.empty())
1738*89c4ff92SAndroid Build Coastguard Worker         {
1739*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("LoadedNetwork::Execute: Number of inputs provided does not match network.");
1740*89c4ff92SAndroid Build Coastguard Worker         }
1741*89c4ff92SAndroid Build Coastguard Worker         else
1742*89c4ff92SAndroid Build Coastguard Worker         {
1743*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("LoadedNetwork::Execute: "
1744*89c4ff92SAndroid Build Coastguard Worker                                            "Number of inputs + preImportedInputs provided does not match network.");
1745*89c4ff92SAndroid Build Coastguard Worker         }
1746*89c4ff92SAndroid Build Coastguard Worker     }
1747*89c4ff92SAndroid Build Coastguard Worker 
1748*89c4ff92SAndroid Build Coastguard Worker     if (outputTensors.size() + preImportedOutputs.size() != graph.GetNumOutputs())
1749*89c4ff92SAndroid Build Coastguard Worker     {
1750*89c4ff92SAndroid Build Coastguard Worker         if (preImportedOutputs.empty())
1751*89c4ff92SAndroid Build Coastguard Worker         {
1752*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("LoadedNetwork::Execute: "
1753*89c4ff92SAndroid Build Coastguard Worker                                            "Number of outputs provided does not match network.");
1754*89c4ff92SAndroid Build Coastguard Worker         }
1755*89c4ff92SAndroid Build Coastguard Worker         else
1756*89c4ff92SAndroid Build Coastguard Worker         {
1757*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("LoadedNetwork::Execute: "
1758*89c4ff92SAndroid Build Coastguard Worker                                            "Number of outputs + preImportedOutputs provided does not match network.");
1759*89c4ff92SAndroid Build Coastguard Worker         }
1760*89c4ff92SAndroid Build Coastguard Worker     }
1761*89c4ff92SAndroid Build Coastguard Worker 
1762*89c4ff92SAndroid Build Coastguard Worker     WorkingMemHandle& workingMemHandle = dynamic_cast<WorkingMemHandle&>(iWorkingMemHandle);
1763*89c4ff92SAndroid Build Coastguard Worker     // Collect all the given LayerBindingIds and check them for duplicates and unknowns.
1764*89c4ff92SAndroid Build Coastguard Worker     std::vector<LayerBindingId>& bindingIds = workingMemHandle.GetBindingIdVector();
1765*89c4ff92SAndroid Build Coastguard Worker     unsigned int index = 0;
1766*89c4ff92SAndroid Build Coastguard Worker     for (auto pair : inputTensors)
1767*89c4ff92SAndroid Build Coastguard Worker     {
1768*89c4ff92SAndroid Build Coastguard Worker         bindingIds[index++] = pair.first;
1769*89c4ff92SAndroid Build Coastguard Worker     }
1770*89c4ff92SAndroid Build Coastguard Worker     for (ImportedInputId id : preImportedInputs)
1771*89c4ff92SAndroid Build Coastguard Worker     {
1772*89c4ff92SAndroid Build Coastguard Worker         bindingIds[index++] = ValidateImportedInputID(id);
1773*89c4ff92SAndroid Build Coastguard Worker     }
1774*89c4ff92SAndroid Build Coastguard Worker     for (auto pair : outputTensors)
1775*89c4ff92SAndroid Build Coastguard Worker     {
1776*89c4ff92SAndroid Build Coastguard Worker         bindingIds[index++] = pair.first;
1777*89c4ff92SAndroid Build Coastguard Worker     }
1778*89c4ff92SAndroid Build Coastguard Worker     for (ImportedOutputId id : preImportedOutputs)
1779*89c4ff92SAndroid Build Coastguard Worker     {
1780*89c4ff92SAndroid Build Coastguard Worker         bindingIds[index++] = ValidateImportedOutputID(id);
1781*89c4ff92SAndroid Build Coastguard Worker     }
1782*89c4ff92SAndroid Build Coastguard Worker 
1783*89c4ff92SAndroid Build Coastguard Worker     workingMemHandle.ValidateBindingIds();
1784*89c4ff92SAndroid Build Coastguard Worker 
1785*89c4ff92SAndroid Build Coastguard Worker     auto resetMemHandle = [&]()
1786*89c4ff92SAndroid Build Coastguard Worker     {
1787*89c4ff92SAndroid Build Coastguard Worker         for (ImportedInputId id: preImportedInputs)
1788*89c4ff92SAndroid Build Coastguard Worker         {
1789*89c4ff92SAndroid Build Coastguard Worker             const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
1790*89c4ff92SAndroid Build Coastguard Worker 
1791*89c4ff92SAndroid Build Coastguard Worker             auto inputHandle = workingMemHandle.GetInputHandle(layerBindingId);
1792*89c4ff92SAndroid Build Coastguard Worker             auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
1793*89c4ff92SAndroid Build Coastguard Worker             for (auto it : inputConnections)
1794*89c4ff92SAndroid Build Coastguard Worker             {
1795*89c4ff92SAndroid Build Coastguard Worker                 *it = inputHandle;
1796*89c4ff92SAndroid Build Coastguard Worker             }
1797*89c4ff92SAndroid Build Coastguard Worker         }
1798*89c4ff92SAndroid Build Coastguard Worker 
1799*89c4ff92SAndroid Build Coastguard Worker         for (ImportedOutputId id: preImportedOutputs)
1800*89c4ff92SAndroid Build Coastguard Worker         {
1801*89c4ff92SAndroid Build Coastguard Worker             const LayerBindingId layerBindingId = m_PreImportedOutputHandles[id].m_LayerBindingId;
1802*89c4ff92SAndroid Build Coastguard Worker 
1803*89c4ff92SAndroid Build Coastguard Worker             auto outputHandle = workingMemHandle.GetOutputHandle(layerBindingId);
1804*89c4ff92SAndroid Build Coastguard Worker             auto outputConnections = workingMemHandle.GetOutputConnection(layerBindingId);
1805*89c4ff92SAndroid Build Coastguard Worker 
1806*89c4ff92SAndroid Build Coastguard Worker             for (auto it : outputConnections)
1807*89c4ff92SAndroid Build Coastguard Worker             {
1808*89c4ff92SAndroid Build Coastguard Worker                 *it = outputHandle;
1809*89c4ff92SAndroid Build Coastguard Worker             }
1810*89c4ff92SAndroid Build Coastguard Worker         }
1811*89c4ff92SAndroid Build Coastguard Worker     };
1812*89c4ff92SAndroid Build Coastguard Worker 
1813*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1814*89c4ff92SAndroid Build Coastguard Worker            TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1815*89c4ff92SAndroid Build Coastguard Worker     ProfilingGuid inferenceGuid = m_ProfilingService->GetNextGuid();
1816*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
1817*89c4ff92SAndroid Build Coastguard Worker     {
1818*89c4ff92SAndroid Build Coastguard Worker         // Add inference timeline trace if profiling is enabled.
1819*89c4ff92SAndroid Build Coastguard Worker        ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
1820*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateTypedEntity(inferenceGuid,LabelsAndEventClasses::INFERENCE_GUID);
1821*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
1822*89c4ff92SAndroid Build Coastguard Worker                                           networkGuid,
1823*89c4ff92SAndroid Build Coastguard Worker                                           inferenceGuid,
1824*89c4ff92SAndroid Build Coastguard Worker                                          LabelsAndEventClasses::EXECUTION_OF_GUID);
1825*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->RecordEvent(inferenceGuid,LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
1826*89c4ff92SAndroid Build Coastguard Worker     }
1827*89c4ff92SAndroid Build Coastguard Worker 
1828*89c4ff92SAndroid Build Coastguard Worker     bool executionSucceeded = true;
1829*89c4ff92SAndroid Build Coastguard Worker 
1830*89c4ff92SAndroid Build Coastguard Worker     if (timelineUtils)
1831*89c4ff92SAndroid Build Coastguard Worker     {
1832*89c4ff92SAndroid Build Coastguard Worker         // Add end of life of the inference timeline if profiling is enabled.
1833*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->RecordEvent(inferenceGuid,LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
1834*89c4ff92SAndroid Build Coastguard Worker         timelineUtils->Commit();
1835*89c4ff92SAndroid Build Coastguard Worker     }
1836*89c4ff92SAndroid Build Coastguard Worker 
1837*89c4ff92SAndroid Build Coastguard Worker     if (!workingMemHandle.IsAllocated())
1838*89c4ff92SAndroid Build Coastguard Worker     {
1839*89c4ff92SAndroid Build Coastguard Worker         workingMemHandle.Allocate();
1840*89c4ff92SAndroid Build Coastguard Worker     }
1841*89c4ff92SAndroid Build Coastguard Worker 
1842*89c4ff92SAndroid Build Coastguard Worker     {
1843*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs");
1844*89c4ff92SAndroid Build Coastguard Worker         for (auto pair : inputTensors)
1845*89c4ff92SAndroid Build Coastguard Worker         {
1846*89c4ff92SAndroid Build Coastguard Worker             EnqueueInput(pair.second, workingMemHandle.GetInputHandle(pair.first));
1847*89c4ff92SAndroid Build Coastguard Worker         }
1848*89c4ff92SAndroid Build Coastguard Worker 
1849*89c4ff92SAndroid Build Coastguard Worker         // Swap in the pre-imported inputs if any
1850*89c4ff92SAndroid Build Coastguard Worker         for (ImportedInputId id : preImportedInputs)
1851*89c4ff92SAndroid Build Coastguard Worker         {
1852*89c4ff92SAndroid Build Coastguard Worker             const ImportedTensorHandlePin& importedInputPin = m_PreImportedInputHandles[id];
1853*89c4ff92SAndroid Build Coastguard Worker             const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
1854*89c4ff92SAndroid Build Coastguard Worker             const auto& preimportedHandle = importedInputPin.m_TensorHandle;
1855*89c4ff92SAndroid Build Coastguard Worker 
1856*89c4ff92SAndroid Build Coastguard Worker             auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
1857*89c4ff92SAndroid Build Coastguard Worker             for (auto it : inputConnections)
1858*89c4ff92SAndroid Build Coastguard Worker             {
1859*89c4ff92SAndroid Build Coastguard Worker                 *it = preimportedHandle.get();
1860*89c4ff92SAndroid Build Coastguard Worker             }
1861*89c4ff92SAndroid Build Coastguard Worker         }
1862*89c4ff92SAndroid Build Coastguard Worker     }
1863*89c4ff92SAndroid Build Coastguard Worker     {
1864*89c4ff92SAndroid Build Coastguard Worker         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs");
1865*89c4ff92SAndroid Build Coastguard Worker         if (m_NetworkProperties.m_ExportEnabled)
1866*89c4ff92SAndroid Build Coastguard Worker         {
1867*89c4ff92SAndroid Build Coastguard Worker             for (auto pair: outputTensors)
1868*89c4ff92SAndroid Build Coastguard Worker             {
1869*89c4ff92SAndroid Build Coastguard Worker                 ImportOutputTensor(pair.second, workingMemHandle.GetOutputHandle(pair.first));
1870*89c4ff92SAndroid Build Coastguard Worker             }
1871*89c4ff92SAndroid Build Coastguard Worker         }
1872*89c4ff92SAndroid Build Coastguard Worker 
1873*89c4ff92SAndroid Build Coastguard Worker         for (ImportedOutputId id : preImportedOutputs)
1874*89c4ff92SAndroid Build Coastguard Worker         {
1875*89c4ff92SAndroid Build Coastguard Worker             const ImportedTensorHandlePin& importedOutputPin = m_PreImportedOutputHandles[id];
1876*89c4ff92SAndroid Build Coastguard Worker             const LayerBindingId layerBindingId = m_PreImportedOutputHandles[id].m_LayerBindingId;
1877*89c4ff92SAndroid Build Coastguard Worker             const auto& preimportedHandle = importedOutputPin.m_TensorHandle;
1878*89c4ff92SAndroid Build Coastguard Worker 
1879*89c4ff92SAndroid Build Coastguard Worker             auto outputConnections = workingMemHandle.GetOutputConnection(layerBindingId);
1880*89c4ff92SAndroid Build Coastguard Worker             for (auto it : outputConnections)
1881*89c4ff92SAndroid Build Coastguard Worker             {
1882*89c4ff92SAndroid Build Coastguard Worker                 *it = preimportedHandle.get();
1883*89c4ff92SAndroid Build Coastguard Worker             }
1884*89c4ff92SAndroid Build Coastguard Worker         }
1885*89c4ff92SAndroid Build Coastguard Worker     }
1886*89c4ff92SAndroid Build Coastguard Worker 
1887*89c4ff92SAndroid Build Coastguard Worker     auto Fail = [&](const std::exception& error)
1888*89c4ff92SAndroid Build Coastguard Worker     {
1889*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << "An error occurred attempting to execute a workload: " << error.what();
1890*89c4ff92SAndroid Build Coastguard Worker         executionSucceeded = false;
1891*89c4ff92SAndroid Build Coastguard Worker     };
1892*89c4ff92SAndroid Build Coastguard Worker     ProfilingDynamicGuid workloadInferenceID(0);
1893*89c4ff92SAndroid Build Coastguard Worker 
1894*89c4ff92SAndroid Build Coastguard Worker     try
1895*89c4ff92SAndroid Build Coastguard Worker     {
1896*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < m_WorkloadQueue.size(); ++i)
1897*89c4ff92SAndroid Build Coastguard Worker         {
1898*89c4ff92SAndroid Build Coastguard Worker             auto& workload = m_WorkloadQueue[i];
1899*89c4ff92SAndroid Build Coastguard Worker             if (timelineUtils)
1900*89c4ff92SAndroid Build Coastguard Worker             {
1901*89c4ff92SAndroid Build Coastguard Worker                 workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(),
1902*89c4ff92SAndroid Build Coastguard Worker                                                                                                 inferenceGuid);
1903*89c4ff92SAndroid Build Coastguard Worker             }
1904*89c4ff92SAndroid Build Coastguard Worker 
1905*89c4ff92SAndroid Build Coastguard Worker             workload->ExecuteAsync(workingMemHandle.GetExecutionDataAt(i).second);
1906*89c4ff92SAndroid Build Coastguard Worker 
1907*89c4ff92SAndroid Build Coastguard Worker             if (timelineUtils)
1908*89c4ff92SAndroid Build Coastguard Worker             {
1909*89c4ff92SAndroid Build Coastguard Worker                 timelineUtils->RecordEndOfLifeEvent(workloadInferenceID);
1910*89c4ff92SAndroid Build Coastguard Worker             }
1911*89c4ff92SAndroid Build Coastguard Worker         }
1912*89c4ff92SAndroid Build Coastguard Worker     }
1913*89c4ff92SAndroid Build Coastguard Worker     catch (const RuntimeException& error)
1914*89c4ff92SAndroid Build Coastguard Worker     {
1915*89c4ff92SAndroid Build Coastguard Worker         resetMemHandle();
1916*89c4ff92SAndroid Build Coastguard Worker         Fail(error);
1917*89c4ff92SAndroid Build Coastguard Worker     }
1918*89c4ff92SAndroid Build Coastguard Worker     catch (const std::runtime_error& error)
1919*89c4ff92SAndroid Build Coastguard Worker     {
1920*89c4ff92SAndroid Build Coastguard Worker         resetMemHandle();
1921*89c4ff92SAndroid Build Coastguard Worker         Fail(error);
1922*89c4ff92SAndroid Build Coastguard Worker     }
1923*89c4ff92SAndroid Build Coastguard Worker     catch (...)
1924*89c4ff92SAndroid Build Coastguard Worker     {
1925*89c4ff92SAndroid Build Coastguard Worker         resetMemHandle();
1926*89c4ff92SAndroid Build Coastguard Worker         throw;
1927*89c4ff92SAndroid Build Coastguard Worker     }
1928*89c4ff92SAndroid Build Coastguard Worker 
1929*89c4ff92SAndroid Build Coastguard Worker     if (!m_NetworkProperties.m_ExportEnabled)
1930*89c4ff92SAndroid Build Coastguard Worker     {
1931*89c4ff92SAndroid Build Coastguard Worker         for (auto pair: outputTensors)
1932*89c4ff92SAndroid Build Coastguard Worker         {
1933*89c4ff92SAndroid Build Coastguard Worker             CopyToOutputTensor(pair.second, workingMemHandle.GetOutputHandle(pair.first));
1934*89c4ff92SAndroid Build Coastguard Worker         }
1935*89c4ff92SAndroid Build Coastguard Worker     }
1936*89c4ff92SAndroid Build Coastguard Worker     else
1937*89c4ff92SAndroid Build Coastguard Worker     {
1938*89c4ff92SAndroid Build Coastguard Worker        ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "SyncMemGeneric_Execute");
1939*89c4ff92SAndroid Build Coastguard Worker        workingMemHandle.MemSyncOutputs();
1940*89c4ff92SAndroid Build Coastguard Worker     }
1941*89c4ff92SAndroid Build Coastguard Worker 
1942*89c4ff92SAndroid Build Coastguard Worker     resetMemHandle();
1943*89c4ff92SAndroid Build Coastguard Worker 
1944*89c4ff92SAndroid Build Coastguard Worker     return executionSucceeded ? Status::Success : Status::Failure;
1945*89c4ff92SAndroid Build Coastguard Worker }
1946*89c4ff92SAndroid Build Coastguard Worker 
1947*89c4ff92SAndroid Build Coastguard Worker /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
1948*89c4ff92SAndroid Build Coastguard Worker /// overlapped Execution by calling this function from different threads.
CreateWorkingMemHandle(NetworkId networkId)1949*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(NetworkId networkId)
1950*89c4ff92SAndroid Build Coastguard Worker {
1951*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
1952*89c4ff92SAndroid Build Coastguard Worker 
1953*89c4ff92SAndroid Build Coastguard Worker     // Tensors that will need to be allocated internally within armnn
1954*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::unique_ptr<ITensorHandle>> managedTensorHandles;
1955*89c4ff92SAndroid Build Coastguard Worker     // Tensors that will be allocated externally by the user
1956*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::unique_ptr<ITensorHandle>> unmanagedTensorHandles;
1957*89c4ff92SAndroid Build Coastguard Worker 
1958*89c4ff92SAndroid Build Coastguard Worker     std::vector<WorkingMemDescriptor> workingMemDescriptors;
1959*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<BackendId, ExecutionData>> executionDataVec;
1960*89c4ff92SAndroid Build Coastguard Worker 
1961*89c4ff92SAndroid Build Coastguard Worker     auto GetTensorHandle = [&](Layer* layer, const OutputSlot& outputSlot)
1962*89c4ff92SAndroid Build Coastguard Worker     {
1963*89c4ff92SAndroid Build Coastguard Worker         ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
1964*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1965*89c4ff92SAndroid Build Coastguard Worker 
1966*89c4ff92SAndroid Build Coastguard Worker         if (factoryId == ITensorHandleFactory::LegacyFactoryId)
1967*89c4ff92SAndroid Build Coastguard Worker         {
1968*89c4ff92SAndroid Build Coastguard Worker             BackendId id = layer->GetBackendId();
1969*89c4ff92SAndroid Build Coastguard Worker             ARMNN_NO_DEPRECATE_WARN_BEGIN
1970*89c4ff92SAndroid Build Coastguard Worker             return m_WorkloadFactories.at(id)->CreateTensorHandle(tensorInfo, false);
1971*89c4ff92SAndroid Build Coastguard Worker             ARMNN_NO_DEPRECATE_WARN_END
1972*89c4ff92SAndroid Build Coastguard Worker         }
1973*89c4ff92SAndroid Build Coastguard Worker         else
1974*89c4ff92SAndroid Build Coastguard Worker         {
1975*89c4ff92SAndroid Build Coastguard Worker             ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1976*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(handleFactory);
1977*89c4ff92SAndroid Build Coastguard Worker             return handleFactory->CreateTensorHandle(tensorInfo, false);
1978*89c4ff92SAndroid Build Coastguard Worker         }
1979*89c4ff92SAndroid Build Coastguard Worker     };
1980*89c4ff92SAndroid Build Coastguard Worker 
1981*89c4ff92SAndroid Build Coastguard Worker     struct HandleInfo
1982*89c4ff92SAndroid Build Coastguard Worker     {
1983*89c4ff92SAndroid Build Coastguard Worker         ITensorHandle* m_TensorHandle;
1984*89c4ff92SAndroid Build Coastguard Worker 
1985*89c4ff92SAndroid Build Coastguard Worker         bool m_IsInputLayerHandle = false;
1986*89c4ff92SAndroid Build Coastguard Worker         bool m_IsOutputLayerHandle = false;
1987*89c4ff92SAndroid Build Coastguard Worker 
1988*89c4ff92SAndroid Build Coastguard Worker         WorkingMemHandle::InputMemDescriptorCoords m_InputMemDescriptorCoords;
1989*89c4ff92SAndroid Build Coastguard Worker         WorkingMemHandle::OutputMemDescriptorCoords m_OutputMemDescriptorCoords;
1990*89c4ff92SAndroid Build Coastguard Worker     };
1991*89c4ff92SAndroid Build Coastguard Worker 
1992*89c4ff92SAndroid Build Coastguard Worker     std::unordered_map<const OutputSlot*, HandleInfo> outputToHandleInfoMap;
1993*89c4ff92SAndroid Build Coastguard Worker 
1994*89c4ff92SAndroid Build Coastguard Worker     unsigned int layerIndex = 0;
1995*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
1996*89c4ff92SAndroid Build Coastguard Worker     {
1997*89c4ff92SAndroid Build Coastguard Worker         // Constant layers execution and management is handled during loaded network construction
1998*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType() == LayerType::Constant)
1999*89c4ff92SAndroid Build Coastguard Worker         {
2000*89c4ff92SAndroid Build Coastguard Worker             continue;
2001*89c4ff92SAndroid Build Coastguard Worker         }
2002*89c4ff92SAndroid Build Coastguard Worker 
2003*89c4ff92SAndroid Build Coastguard Worker         WorkingMemDescriptor workingMemDescriptor;
2004*89c4ff92SAndroid Build Coastguard Worker 
2005*89c4ff92SAndroid Build Coastguard Worker         bool isMemoryManaged = true;
2006*89c4ff92SAndroid Build Coastguard Worker         bool isInputLayer = false;
2007*89c4ff92SAndroid Build Coastguard Worker         bool isOutputLayer = false;
2008*89c4ff92SAndroid Build Coastguard Worker         bool isConnectedToOutputLayer = false;
2009*89c4ff92SAndroid Build Coastguard Worker 
2010*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType() == LayerType::Input || layer->GetType() == LayerType::MemImport)
2011*89c4ff92SAndroid Build Coastguard Worker         {
2012*89c4ff92SAndroid Build Coastguard Worker             // Input layers/workloads will not be executed so the descriptor is not added to workingMemDescriptors
2013*89c4ff92SAndroid Build Coastguard Worker             // However we will still need to manage the tensorHandle
2014*89c4ff92SAndroid Build Coastguard Worker             isInputLayer = true;
2015*89c4ff92SAndroid Build Coastguard Worker             isMemoryManaged = !m_NetworkProperties.m_ImportEnabled;
2016*89c4ff92SAndroid Build Coastguard Worker         }
2017*89c4ff92SAndroid Build Coastguard Worker         else if (layer->GetType() == LayerType::Output)
2018*89c4ff92SAndroid Build Coastguard Worker         {
2019*89c4ff92SAndroid Build Coastguard Worker             isOutputLayer = true;
2020*89c4ff92SAndroid Build Coastguard Worker         }
2021*89c4ff92SAndroid Build Coastguard Worker 
2022*89c4ff92SAndroid Build Coastguard Worker         unsigned int slotIndex = 0;
2023*89c4ff92SAndroid Build Coastguard Worker         // Create a tensor handle for each output slot of a layer
2024*89c4ff92SAndroid Build Coastguard Worker         // Once we create it, we start managing its lifetime
2025*89c4ff92SAndroid Build Coastguard Worker         for (auto& slot : layer->GetOutputSlots())
2026*89c4ff92SAndroid Build Coastguard Worker         {
2027*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int i = 0; i < slot.GetNumConnections(); ++i)
2028*89c4ff92SAndroid Build Coastguard Worker             {
2029*89c4ff92SAndroid Build Coastguard Worker                 if ((slot.GetConnection(i)->GetOwningLayer().GetType() == LayerType::Output))
2030*89c4ff92SAndroid Build Coastguard Worker                 {
2031*89c4ff92SAndroid Build Coastguard Worker                     if (!isConnectedToOutputLayer)
2032*89c4ff92SAndroid Build Coastguard Worker                     {
2033*89c4ff92SAndroid Build Coastguard Worker                         isConnectedToOutputLayer = true;
2034*89c4ff92SAndroid Build Coastguard Worker                         // If Export is enabled disable memory management, so we can export, otherwise we do a copy
2035*89c4ff92SAndroid Build Coastguard Worker                         isMemoryManaged = !m_NetworkProperties.m_ExportEnabled;
2036*89c4ff92SAndroid Build Coastguard Worker                     }
2037*89c4ff92SAndroid Build Coastguard Worker                     else
2038*89c4ff92SAndroid Build Coastguard Worker                     {
2039*89c4ff92SAndroid Build Coastguard Worker                         // Importing in this case would likely cause unexpected behaviour, so we disallow it.
2040*89c4ff92SAndroid Build Coastguard Worker                         ARMNN_LOG(warning) <<
2041*89c4ff92SAndroid Build Coastguard Worker                            fmt::format("Layer name: '{0}' guid: '{1}' has two or more OutputLayers connected to it. "
2042*89c4ff92SAndroid Build Coastguard Worker                                        "This will prevent importing on the connected OutputLayers.",
2043*89c4ff92SAndroid Build Coastguard Worker                                         layer->GetName(), layer->GetGuid());
2044*89c4ff92SAndroid Build Coastguard Worker                         isMemoryManaged = true;
2045*89c4ff92SAndroid Build Coastguard Worker                     }
2046*89c4ff92SAndroid Build Coastguard Worker                 }
2047*89c4ff92SAndroid Build Coastguard Worker             }
2048*89c4ff92SAndroid Build Coastguard Worker 
2049*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* tensorHandle;
2050*89c4ff92SAndroid Build Coastguard Worker             if (isMemoryManaged)
2051*89c4ff92SAndroid Build Coastguard Worker             {
2052*89c4ff92SAndroid Build Coastguard Worker                 managedTensorHandles.emplace_back(GetTensorHandle(layer, slot));
2053*89c4ff92SAndroid Build Coastguard Worker                 tensorHandle = managedTensorHandles.back().get();
2054*89c4ff92SAndroid Build Coastguard Worker             }
2055*89c4ff92SAndroid Build Coastguard Worker             else
2056*89c4ff92SAndroid Build Coastguard Worker             {
2057*89c4ff92SAndroid Build Coastguard Worker                 unmanagedTensorHandles.emplace_back(GetTensorHandle(layer, slot));
2058*89c4ff92SAndroid Build Coastguard Worker                 tensorHandle = unmanagedTensorHandles.back().get();
2059*89c4ff92SAndroid Build Coastguard Worker             }
2060*89c4ff92SAndroid Build Coastguard Worker 
2061*89c4ff92SAndroid Build Coastguard Worker             workingMemDescriptor.m_Outputs.push_back(tensorHandle);
2062*89c4ff92SAndroid Build Coastguard Worker 
2063*89c4ff92SAndroid Build Coastguard Worker             HandleInfo& handleInfo = outputToHandleInfoMap[&slot];
2064*89c4ff92SAndroid Build Coastguard Worker             handleInfo.m_TensorHandle = tensorHandle;
2065*89c4ff92SAndroid Build Coastguard Worker 
2066*89c4ff92SAndroid Build Coastguard Worker             // Store the coordinates of the current layer's OutputSlot that is connected to the OutputLayer
2067*89c4ff92SAndroid Build Coastguard Worker             if (isConnectedToOutputLayer)
2068*89c4ff92SAndroid Build Coastguard Worker             {
2069*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_IsOutputLayerHandle = true;
2070*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_OutputMemDescriptorCoords.m_OutputSlotCoords = {layerIndex, slotIndex};
2071*89c4ff92SAndroid Build Coastguard Worker             }
2072*89c4ff92SAndroid Build Coastguard Worker             // Store the LayerBindingId of the InputLayer
2073*89c4ff92SAndroid Build Coastguard Worker             if (isInputLayer)
2074*89c4ff92SAndroid Build Coastguard Worker             {
2075*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_IsInputLayerHandle = true;
2076*89c4ff92SAndroid Build Coastguard Worker                 LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2077*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_InputMemDescriptorCoords.m_LayerBindingId = bindingId;
2078*89c4ff92SAndroid Build Coastguard Worker             }
2079*89c4ff92SAndroid Build Coastguard Worker             slotIndex++;
2080*89c4ff92SAndroid Build Coastguard Worker         }
2081*89c4ff92SAndroid Build Coastguard Worker         // Loop through the input slots in the same layer and decrement the reference counter associated
2082*89c4ff92SAndroid Build Coastguard Worker         // to each tensor handle we encounter.
2083*89c4ff92SAndroid Build Coastguard Worker         // Once it reaches zero, the lifetime of the tensor handle has ended, and we mark its memory as available
2084*89c4ff92SAndroid Build Coastguard Worker         // so that the next tensor handle with a non overlapping lifetime can share its memory.
2085*89c4ff92SAndroid Build Coastguard Worker         for (auto& slot : layer->GetInputSlots())
2086*89c4ff92SAndroid Build Coastguard Worker         {
2087*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(slot.GetConnection());
2088*89c4ff92SAndroid Build Coastguard Worker             auto outputSlot = slot.GetConnectedOutputSlot();
2089*89c4ff92SAndroid Build Coastguard Worker             auto key = outputSlot->GetOwningLayer().GetGuid();
2090*89c4ff92SAndroid Build Coastguard Worker 
2091*89c4ff92SAndroid Build Coastguard Worker             // Constant layers execution and management is handled during loaded network construction
2092*89c4ff92SAndroid Build Coastguard Worker             auto found = m_ConstantTensorHandles.find(key);
2093*89c4ff92SAndroid Build Coastguard Worker             if (found != m_ConstantTensorHandles.end())
2094*89c4ff92SAndroid Build Coastguard Worker             {
2095*89c4ff92SAndroid Build Coastguard Worker                 ITensorHandle* tensorHandle = found->second;
2096*89c4ff92SAndroid Build Coastguard Worker                 workingMemDescriptor.m_Inputs.push_back(tensorHandle);
2097*89c4ff92SAndroid Build Coastguard Worker 
2098*89c4ff92SAndroid Build Coastguard Worker                 // Odd case where a constant layer is connected to an output layer
2099*89c4ff92SAndroid Build Coastguard Worker                 // We will need to create a HandleInfo to track it
2100*89c4ff92SAndroid Build Coastguard Worker                 if (isOutputLayer)
2101*89c4ff92SAndroid Build Coastguard Worker                 {
2102*89c4ff92SAndroid Build Coastguard Worker                     LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2103*89c4ff92SAndroid Build Coastguard Worker 
2104*89c4ff92SAndroid Build Coastguard Worker                     HandleInfo& handleInfo = outputToHandleInfoMap[outputSlot];
2105*89c4ff92SAndroid Build Coastguard Worker                     handleInfo.m_TensorHandle = tensorHandle;
2106*89c4ff92SAndroid Build Coastguard Worker                     handleInfo.m_IsOutputLayerHandle = true;
2107*89c4ff92SAndroid Build Coastguard Worker                     handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
2108*89c4ff92SAndroid Build Coastguard Worker                     handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
2109*89c4ff92SAndroid Build Coastguard Worker                 }
2110*89c4ff92SAndroid Build Coastguard Worker                 continue;
2111*89c4ff92SAndroid Build Coastguard Worker             }
2112*89c4ff92SAndroid Build Coastguard Worker 
2113*89c4ff92SAndroid Build Coastguard Worker             HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot);
2114*89c4ff92SAndroid Build Coastguard Worker 
2115*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle;
2116*89c4ff92SAndroid Build Coastguard Worker             workingMemDescriptor.m_Inputs.push_back(inputTensorHandle);
2117*89c4ff92SAndroid Build Coastguard Worker 
2118*89c4ff92SAndroid Build Coastguard Worker             // Store the LayerBindingId of the OutputLayer
2119*89c4ff92SAndroid Build Coastguard Worker             if (isOutputLayer)
2120*89c4ff92SAndroid Build Coastguard Worker             {
2121*89c4ff92SAndroid Build Coastguard Worker                 LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2122*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
2123*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
2124*89c4ff92SAndroid Build Coastguard Worker             }
2125*89c4ff92SAndroid Build Coastguard Worker             // In this case the layer is not an Output Layer but shares its input tensorhandle with an OutputLayer
2126*89c4ff92SAndroid Build Coastguard Worker             // It will need to be updated as well, if we swap out the tensorhandle
2127*89c4ff92SAndroid Build Coastguard Worker             else if (handleInfo.m_IsOutputLayerHandle)
2128*89c4ff92SAndroid Build Coastguard Worker             {
2129*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, slot.GetSlotIndex()});
2130*89c4ff92SAndroid Build Coastguard Worker             }
2131*89c4ff92SAndroid Build Coastguard Worker 
2132*89c4ff92SAndroid Build Coastguard Worker             // Store the coordinates of the InputSlots connected to the InputLayer
2133*89c4ff92SAndroid Build Coastguard Worker             // There can be more than one InputSlot connected to an InputLayer, so we use a vector
2134*89c4ff92SAndroid Build Coastguard Worker             if (handleInfo.m_IsInputLayerHandle)
2135*89c4ff92SAndroid Build Coastguard Worker             {
2136*89c4ff92SAndroid Build Coastguard Worker                 std::pair<LayerGuid, unsigned int> connectionLocation{layerIndex, slot.GetSlotIndex()};
2137*89c4ff92SAndroid Build Coastguard Worker                 handleInfo.m_InputMemDescriptorCoords.m_InputSlotCoords.emplace_back(connectionLocation);
2138*89c4ff92SAndroid Build Coastguard Worker             }
2139*89c4ff92SAndroid Build Coastguard Worker         }
2140*89c4ff92SAndroid Build Coastguard Worker 
2141*89c4ff92SAndroid Build Coastguard Worker         // Input/Output layers/workloads will not be executed, so the descriptor is not added to workingMemDescriptors
2142*89c4ff92SAndroid Build Coastguard Worker         // However we will still need to manage the tensorHandle
2143*89c4ff92SAndroid Build Coastguard Worker         if (!isInputLayer)
2144*89c4ff92SAndroid Build Coastguard Worker         {
2145*89c4ff92SAndroid Build Coastguard Worker             // Simply auto initialise ExecutionData here, so it's added only for the layer that require execution.
2146*89c4ff92SAndroid Build Coastguard Worker             // The memory and data will be allocated/assigned for the void* in WorkingMemHandle::Allocate.
2147*89c4ff92SAndroid Build Coastguard Worker             std::pair<BackendId, ExecutionData> dataPair;
2148*89c4ff92SAndroid Build Coastguard Worker             dataPair.first = layer->GetBackendId();
2149*89c4ff92SAndroid Build Coastguard Worker 
2150*89c4ff92SAndroid Build Coastguard Worker             executionDataVec.push_back(dataPair);
2151*89c4ff92SAndroid Build Coastguard Worker             workingMemDescriptors.push_back(workingMemDescriptor);
2152*89c4ff92SAndroid Build Coastguard Worker 
2153*89c4ff92SAndroid Build Coastguard Worker             layerIndex++;
2154*89c4ff92SAndroid Build Coastguard Worker         }
2155*89c4ff92SAndroid Build Coastguard Worker     }
2156*89c4ff92SAndroid Build Coastguard Worker 
2157*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<std::shared_ptr<TensorMemory>, MemorySource>> tensorMemory;
2158*89c4ff92SAndroid Build Coastguard Worker 
2159*89c4ff92SAndroid Build Coastguard Worker     auto externalMemoryManager = CreateExternalMemoryManger(tensorMemory);
2160*89c4ff92SAndroid Build Coastguard Worker 
2161*89c4ff92SAndroid Build Coastguard Worker     // Sort m_TensorMemory, so it's order matches the outputSlot order
2162*89c4ff92SAndroid Build Coastguard Worker     std::sort(tensorMemory.begin(), tensorMemory.end(),
2163*89c4ff92SAndroid Build Coastguard Worker               [](const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& lhs,
2164*89c4ff92SAndroid Build Coastguard Worker                  const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& rhs)
2165*89c4ff92SAndroid Build Coastguard Worker               {
2166*89c4ff92SAndroid Build Coastguard Worker                   return lhs.first->m_OutputSlotId < rhs.first->m_OutputSlotId;
2167*89c4ff92SAndroid Build Coastguard Worker               });
2168*89c4ff92SAndroid Build Coastguard Worker 
2169*89c4ff92SAndroid Build Coastguard Worker     std::vector<WorkingMemHandle::InputMemDescriptorCoords> inputConnectionsInfo;
2170*89c4ff92SAndroid Build Coastguard Worker     std::vector<WorkingMemHandle::OutputMemDescriptorCoords> outputConnectionsInfo;
2171*89c4ff92SAndroid Build Coastguard Worker 
2172*89c4ff92SAndroid Build Coastguard Worker     for (const auto& handleInfo: outputToHandleInfoMap)
2173*89c4ff92SAndroid Build Coastguard Worker     {
2174*89c4ff92SAndroid Build Coastguard Worker         if (handleInfo.second.m_IsOutputLayerHandle)
2175*89c4ff92SAndroid Build Coastguard Worker         {
2176*89c4ff92SAndroid Build Coastguard Worker             outputConnectionsInfo.emplace_back(handleInfo.second.m_OutputMemDescriptorCoords);
2177*89c4ff92SAndroid Build Coastguard Worker         }
2178*89c4ff92SAndroid Build Coastguard Worker 
2179*89c4ff92SAndroid Build Coastguard Worker         if (handleInfo.second.m_IsInputLayerHandle)
2180*89c4ff92SAndroid Build Coastguard Worker         {
2181*89c4ff92SAndroid Build Coastguard Worker             inputConnectionsInfo.emplace_back(handleInfo.second.m_InputMemDescriptorCoords);
2182*89c4ff92SAndroid Build Coastguard Worker         }
2183*89c4ff92SAndroid Build Coastguard Worker     }
2184*89c4ff92SAndroid Build Coastguard Worker 
2185*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<WorkingMemHandle>(networkId,
2186*89c4ff92SAndroid Build Coastguard Worker                                               inputConnectionsInfo,
2187*89c4ff92SAndroid Build Coastguard Worker                                               outputConnectionsInfo,
2188*89c4ff92SAndroid Build Coastguard Worker                                               workingMemDescriptors,
2189*89c4ff92SAndroid Build Coastguard Worker                                               std::move(externalMemoryManager),
2190*89c4ff92SAndroid Build Coastguard Worker                                               std::move(tensorMemory),
2191*89c4ff92SAndroid Build Coastguard Worker                                               std::move(managedTensorHandles),
2192*89c4ff92SAndroid Build Coastguard Worker                                               std::move(unmanagedTensorHandles),
2193*89c4ff92SAndroid Build Coastguard Worker                                               executionDataVec,
2194*89c4ff92SAndroid Build Coastguard Worker                                               &m_Backends);
2195*89c4ff92SAndroid Build Coastguard Worker }
2196*89c4ff92SAndroid Build Coastguard Worker 
RegisterDebugCallback(const DebugCallbackFunction & func)2197*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func)
2198*89c4ff92SAndroid Build Coastguard Worker {
2199*89c4ff92SAndroid Build Coastguard Worker     for (auto&& workloadPtr: m_WorkloadQueue)
2200*89c4ff92SAndroid Build Coastguard Worker     {
2201*89c4ff92SAndroid Build Coastguard Worker         workloadPtr.get()->RegisterDebugCallback(func);
2202*89c4ff92SAndroid Build Coastguard Worker     }
2203*89c4ff92SAndroid Build Coastguard Worker }
2204*89c4ff92SAndroid Build Coastguard Worker 
2205*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryProfileAsync()2206*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::CreateMemoryProfileAsync()
2207*89c4ff92SAndroid Build Coastguard Worker {
2208*89c4ff92SAndroid Build Coastguard Worker     struct PartialBlock
2209*89c4ff92SAndroid Build Coastguard Worker     {
2210*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_StartOfLife;
2211*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_Lifetime;
2212*89c4ff92SAndroid Build Coastguard Worker 
2213*89c4ff92SAndroid Build Coastguard Worker         size_t m_MemSize;
2214*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_Index;
2215*89c4ff92SAndroid Build Coastguard Worker 
2216*89c4ff92SAndroid Build Coastguard Worker         BackendId m_BackendId;
2217*89c4ff92SAndroid Build Coastguard Worker     };
2218*89c4ff92SAndroid Build Coastguard Worker 
2219*89c4ff92SAndroid Build Coastguard Worker     auto align = [](size_t numToAlign)
2220*89c4ff92SAndroid Build Coastguard Worker     {
2221*89c4ff92SAndroid Build Coastguard Worker         const size_t alignment = sizeof(float);
2222*89c4ff92SAndroid Build Coastguard Worker         return ((numToAlign + alignment - 1) / alignment) * alignment;
2223*89c4ff92SAndroid Build Coastguard Worker     };
2224*89c4ff92SAndroid Build Coastguard Worker 
2225*89c4ff92SAndroid Build Coastguard Worker     std::unordered_map<const OutputSlot*, PartialBlock> memBlockTrackerMap;
2226*89c4ff92SAndroid Build Coastguard Worker 
2227*89c4ff92SAndroid Build Coastguard Worker     const bool inputImportingEnabled = m_NetworkProperties.m_InputSource != MemorySource::Undefined;
2228*89c4ff92SAndroid Build Coastguard Worker     const bool outputImportingEnabled = m_NetworkProperties.m_OutputSource != MemorySource::Undefined;
2229*89c4ff92SAndroid Build Coastguard Worker 
2230*89c4ff92SAndroid Build Coastguard Worker     unsigned int timestep = 0;
2231*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputIndex = 0;
2232*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
2233*89c4ff92SAndroid Build Coastguard Worker 
2234*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
2235*89c4ff92SAndroid Build Coastguard Worker     {
2236*89c4ff92SAndroid Build Coastguard Worker         const LayerType& layerType = layer->GetType();
2237*89c4ff92SAndroid Build Coastguard Worker         // Don't manage memory if importing.
2238*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Input && inputImportingEnabled)
2239*89c4ff92SAndroid Build Coastguard Worker         {
2240*89c4ff92SAndroid Build Coastguard Worker             continue;
2241*89c4ff92SAndroid Build Coastguard Worker         }
2242*89c4ff92SAndroid Build Coastguard Worker         // Don't manage memory if importing.
2243*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Output && outputImportingEnabled
2244*89c4ff92SAndroid Build Coastguard Worker             && layer->GetInputSlot(0).GetConnectedOutputSlot()->GetNumConnections() == 1)
2245*89c4ff92SAndroid Build Coastguard Worker         {
2246*89c4ff92SAndroid Build Coastguard Worker             continue;
2247*89c4ff92SAndroid Build Coastguard Worker         }
2248*89c4ff92SAndroid Build Coastguard Worker         // Because Constant Layer memory can not be shared, the memory must persist for the lifetime of execution,
2249*89c4ff92SAndroid Build Coastguard Worker         // management is done separately.
2250*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Constant)
2251*89c4ff92SAndroid Build Coastguard Worker         {
2252*89c4ff92SAndroid Build Coastguard Worker             continue;
2253*89c4ff92SAndroid Build Coastguard Worker         }
2254*89c4ff92SAndroid Build Coastguard Worker 
2255*89c4ff92SAndroid Build Coastguard Worker         BackendId backendId = layer->GetBackendId();
2256*89c4ff92SAndroid Build Coastguard Worker         for (auto& outputSlot : layer->GetOutputSlots())
2257*89c4ff92SAndroid Build Coastguard Worker         {
2258*89c4ff92SAndroid Build Coastguard Worker             if (!m_SupportsExternallyManagedMemory[backendId])
2259*89c4ff92SAndroid Build Coastguard Worker             {
2260*89c4ff92SAndroid Build Coastguard Worker                 continue;
2261*89c4ff92SAndroid Build Coastguard Worker             }
2262*89c4ff92SAndroid Build Coastguard Worker 
2263*89c4ff92SAndroid Build Coastguard Worker             PartialBlock partialBlock;
2264*89c4ff92SAndroid Build Coastguard Worker 
2265*89c4ff92SAndroid Build Coastguard Worker             partialBlock.m_StartOfLife = timestep;
2266*89c4ff92SAndroid Build Coastguard Worker 
2267*89c4ff92SAndroid Build Coastguard Worker             size_t alignedSize = align(outputSlot.GetOutputHandler().GetTensorInfo().GetNumBytes());
2268*89c4ff92SAndroid Build Coastguard Worker             partialBlock.m_MemSize = alignedSize;
2269*89c4ff92SAndroid Build Coastguard Worker             partialBlock.m_Index = outputIndex++;
2270*89c4ff92SAndroid Build Coastguard Worker             partialBlock.m_Lifetime = outputSlot.GetNumConnections();
2271*89c4ff92SAndroid Build Coastguard Worker             partialBlock.m_BackendId = backendId;
2272*89c4ff92SAndroid Build Coastguard Worker 
2273*89c4ff92SAndroid Build Coastguard Worker             if (partialBlock.m_Lifetime == 0)
2274*89c4ff92SAndroid Build Coastguard Worker             {
2275*89c4ff92SAndroid Build Coastguard Worker                 m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2276*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_StartOfLife,
2277*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_MemSize,
2278*89c4ff92SAndroid Build Coastguard Worker                                                                      0,
2279*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_Index);
2280*89c4ff92SAndroid Build Coastguard Worker             }
2281*89c4ff92SAndroid Build Coastguard Worker             else
2282*89c4ff92SAndroid Build Coastguard Worker             {
2283*89c4ff92SAndroid Build Coastguard Worker                 memBlockTrackerMap[&outputSlot] = partialBlock;
2284*89c4ff92SAndroid Build Coastguard Worker             }
2285*89c4ff92SAndroid Build Coastguard Worker         }
2286*89c4ff92SAndroid Build Coastguard Worker 
2287*89c4ff92SAndroid Build Coastguard Worker         for (auto& inputSlot : layer->GetInputSlots())
2288*89c4ff92SAndroid Build Coastguard Worker         {
2289*89c4ff92SAndroid Build Coastguard Worker             const Layer& connectedInputLayer = inputSlot.GetConnectedOutputSlot()->GetOwningLayer();
2290*89c4ff92SAndroid Build Coastguard Worker             const LayerType& owningLayerType = connectedInputLayer.GetType();
2291*89c4ff92SAndroid Build Coastguard Worker 
2292*89c4ff92SAndroid Build Coastguard Worker             if (owningLayerType == LayerType::Constant)
2293*89c4ff92SAndroid Build Coastguard Worker             {
2294*89c4ff92SAndroid Build Coastguard Worker                 continue;
2295*89c4ff92SAndroid Build Coastguard Worker             }
2296*89c4ff92SAndroid Build Coastguard Worker             if (inputImportingEnabled && owningLayerType == LayerType::Input)
2297*89c4ff92SAndroid Build Coastguard Worker             {
2298*89c4ff92SAndroid Build Coastguard Worker                 continue;
2299*89c4ff92SAndroid Build Coastguard Worker             }
2300*89c4ff92SAndroid Build Coastguard Worker 
2301*89c4ff92SAndroid Build Coastguard Worker             auto outputSlot = inputSlot.GetConnectedOutputSlot();
2302*89c4ff92SAndroid Build Coastguard Worker 
2303*89c4ff92SAndroid Build Coastguard Worker             PartialBlock& partialBlock = memBlockTrackerMap.at(outputSlot);
2304*89c4ff92SAndroid Build Coastguard Worker 
2305*89c4ff92SAndroid Build Coastguard Worker             auto& lifetime = partialBlock.m_Lifetime;
2306*89c4ff92SAndroid Build Coastguard Worker             --lifetime;
2307*89c4ff92SAndroid Build Coastguard Worker 
2308*89c4ff92SAndroid Build Coastguard Worker             if (lifetime == 0)
2309*89c4ff92SAndroid Build Coastguard Worker             {
2310*89c4ff92SAndroid Build Coastguard Worker                 m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2311*89c4ff92SAndroid Build Coastguard Worker                                                                      timestep,
2312*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_MemSize,
2313*89c4ff92SAndroid Build Coastguard Worker                                                                      0,
2314*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_Index);
2315*89c4ff92SAndroid Build Coastguard Worker             }
2316*89c4ff92SAndroid Build Coastguard Worker         }
2317*89c4ff92SAndroid Build Coastguard Worker         ++timestep;
2318*89c4ff92SAndroid Build Coastguard Worker     }
2319*89c4ff92SAndroid Build Coastguard Worker }
2320*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryProfile()2321*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::CreateMemoryProfile()
2322*89c4ff92SAndroid Build Coastguard Worker {
2323*89c4ff92SAndroid Build Coastguard Worker     // Finds the first TensorHandle ancestor of a SubTensorHandle. If the ITensorHandle provided
2324*89c4ff92SAndroid Build Coastguard Worker     // is a TensorHandle, the function just returns it
2325*89c4ff92SAndroid Build Coastguard Worker     auto TraceSubTensorHandleAncestry = [](ITensorHandle* const subTensorHandle)
2326*89c4ff92SAndroid Build Coastguard Worker     {
2327*89c4ff92SAndroid Build Coastguard Worker         ITensorHandle* ancestor = subTensorHandle;
2328*89c4ff92SAndroid Build Coastguard Worker         while (ancestor && ancestor->GetParent())
2329*89c4ff92SAndroid Build Coastguard Worker         {
2330*89c4ff92SAndroid Build Coastguard Worker             ancestor = ancestor->GetParent();
2331*89c4ff92SAndroid Build Coastguard Worker         }
2332*89c4ff92SAndroid Build Coastguard Worker         return ancestor;
2333*89c4ff92SAndroid Build Coastguard Worker     };
2334*89c4ff92SAndroid Build Coastguard Worker 
2335*89c4ff92SAndroid Build Coastguard Worker     struct PartialBlock
2336*89c4ff92SAndroid Build Coastguard Worker     {
2337*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_StartOfLife;
2338*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_Lifetime;
2339*89c4ff92SAndroid Build Coastguard Worker 
2340*89c4ff92SAndroid Build Coastguard Worker         size_t m_MemSize;
2341*89c4ff92SAndroid Build Coastguard Worker         unsigned int m_Index;
2342*89c4ff92SAndroid Build Coastguard Worker 
2343*89c4ff92SAndroid Build Coastguard Worker         BackendId m_BackendId;
2344*89c4ff92SAndroid Build Coastguard Worker     };
2345*89c4ff92SAndroid Build Coastguard Worker 
2346*89c4ff92SAndroid Build Coastguard Worker     auto align = [](size_t numToAlign)
2347*89c4ff92SAndroid Build Coastguard Worker     {
2348*89c4ff92SAndroid Build Coastguard Worker         const size_t alignment = sizeof(float);
2349*89c4ff92SAndroid Build Coastguard Worker         return ((numToAlign + alignment - 1) / alignment) * alignment;
2350*89c4ff92SAndroid Build Coastguard Worker     };
2351*89c4ff92SAndroid Build Coastguard Worker 
2352*89c4ff92SAndroid Build Coastguard Worker     std::unordered_map<ITensorHandle*, PartialBlock> memBlockTrackerMap;
2353*89c4ff92SAndroid Build Coastguard Worker 
2354*89c4ff92SAndroid Build Coastguard Worker     const bool inputImportingEnabled = m_NetworkProperties.m_InputSource != MemorySource::Undefined;
2355*89c4ff92SAndroid Build Coastguard Worker     const bool outputImportingEnabled = m_NetworkProperties.m_OutputSource != MemorySource::Undefined;
2356*89c4ff92SAndroid Build Coastguard Worker 
2357*89c4ff92SAndroid Build Coastguard Worker     unsigned int timestep = 0;
2358*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputIndex = 0;
2359*89c4ff92SAndroid Build Coastguard Worker     Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
2360*89c4ff92SAndroid Build Coastguard Worker 
2361*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : order)
2362*89c4ff92SAndroid Build Coastguard Worker     {
2363*89c4ff92SAndroid Build Coastguard Worker         const LayerType& layerType = layer->GetType();
2364*89c4ff92SAndroid Build Coastguard Worker         // Don't manage memory if importing.
2365*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Input && inputImportingEnabled)
2366*89c4ff92SAndroid Build Coastguard Worker         {
2367*89c4ff92SAndroid Build Coastguard Worker             continue;
2368*89c4ff92SAndroid Build Coastguard Worker         }
2369*89c4ff92SAndroid Build Coastguard Worker         // Don't manage memory if importing.
2370*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Output && outputImportingEnabled
2371*89c4ff92SAndroid Build Coastguard Worker             && layer->GetInputSlot(0).GetConnectedOutputSlot()->GetNumConnections() == 1)
2372*89c4ff92SAndroid Build Coastguard Worker         {
2373*89c4ff92SAndroid Build Coastguard Worker             continue;
2374*89c4ff92SAndroid Build Coastguard Worker         }
2375*89c4ff92SAndroid Build Coastguard Worker         // Because Constant Layer memory can not be shared, the memory must persist for the lifetime of execution,
2376*89c4ff92SAndroid Build Coastguard Worker         // management is done separately.
2377*89c4ff92SAndroid Build Coastguard Worker         if (layerType == LayerType::Constant)
2378*89c4ff92SAndroid Build Coastguard Worker         {
2379*89c4ff92SAndroid Build Coastguard Worker             continue;
2380*89c4ff92SAndroid Build Coastguard Worker         }
2381*89c4ff92SAndroid Build Coastguard Worker 
2382*89c4ff92SAndroid Build Coastguard Worker         BackendId backendId = layer->GetBackendId();
2383*89c4ff92SAndroid Build Coastguard Worker         for (auto& outputSlot : layer->GetOutputSlots())
2384*89c4ff92SAndroid Build Coastguard Worker         {
2385*89c4ff92SAndroid Build Coastguard Worker             if (!m_SupportsExternallyManagedMemory[backendId])
2386*89c4ff92SAndroid Build Coastguard Worker             {
2387*89c4ff92SAndroid Build Coastguard Worker                 continue;
2388*89c4ff92SAndroid Build Coastguard Worker             }
2389*89c4ff92SAndroid Build Coastguard Worker 
2390*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* tensorHandle = outputSlot.GetOutputHandler().GetData();
2391*89c4ff92SAndroid Build Coastguard Worker             tensorHandle = TraceSubTensorHandleAncestry(tensorHandle);
2392*89c4ff92SAndroid Build Coastguard Worker 
2393*89c4ff92SAndroid Build Coastguard Worker             if (memBlockTrackerMap.find(tensorHandle) == memBlockTrackerMap.end())
2394*89c4ff92SAndroid Build Coastguard Worker             {
2395*89c4ff92SAndroid Build Coastguard Worker                 PartialBlock partialBlock;
2396*89c4ff92SAndroid Build Coastguard Worker 
2397*89c4ff92SAndroid Build Coastguard Worker                 partialBlock.m_StartOfLife = timestep;
2398*89c4ff92SAndroid Build Coastguard Worker 
2399*89c4ff92SAndroid Build Coastguard Worker                 size_t alignedSize = align(outputSlot.GetOutputHandler().GetTensorInfo().GetNumBytes());
2400*89c4ff92SAndroid Build Coastguard Worker                 partialBlock.m_MemSize = alignedSize;
2401*89c4ff92SAndroid Build Coastguard Worker                 partialBlock.m_Index = outputIndex++;
2402*89c4ff92SAndroid Build Coastguard Worker                 partialBlock.m_Lifetime = outputSlot.GetNumConnections();
2403*89c4ff92SAndroid Build Coastguard Worker                 partialBlock.m_BackendId = backendId;
2404*89c4ff92SAndroid Build Coastguard Worker 
2405*89c4ff92SAndroid Build Coastguard Worker                 if (partialBlock.m_Lifetime == 0)
2406*89c4ff92SAndroid Build Coastguard Worker                 {
2407*89c4ff92SAndroid Build Coastguard Worker                     m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2408*89c4ff92SAndroid Build Coastguard Worker                                                                          partialBlock.m_StartOfLife,
2409*89c4ff92SAndroid Build Coastguard Worker                                                                          partialBlock.m_MemSize,
2410*89c4ff92SAndroid Build Coastguard Worker                                                                          0,
2411*89c4ff92SAndroid Build Coastguard Worker                                                                          partialBlock.m_Index);
2412*89c4ff92SAndroid Build Coastguard Worker                 }
2413*89c4ff92SAndroid Build Coastguard Worker                 else
2414*89c4ff92SAndroid Build Coastguard Worker                 {
2415*89c4ff92SAndroid Build Coastguard Worker                     memBlockTrackerMap[tensorHandle] = partialBlock;
2416*89c4ff92SAndroid Build Coastguard Worker                 }
2417*89c4ff92SAndroid Build Coastguard Worker                 m_Tensorhandles.push_back(tensorHandle);
2418*89c4ff92SAndroid Build Coastguard Worker 
2419*89c4ff92SAndroid Build Coastguard Worker             }
2420*89c4ff92SAndroid Build Coastguard Worker             else
2421*89c4ff92SAndroid Build Coastguard Worker             {
2422*89c4ff92SAndroid Build Coastguard Worker                 memBlockTrackerMap.at(tensorHandle).m_Lifetime += outputSlot.GetNumConnections();
2423*89c4ff92SAndroid Build Coastguard Worker             }
2424*89c4ff92SAndroid Build Coastguard Worker         }
2425*89c4ff92SAndroid Build Coastguard Worker 
2426*89c4ff92SAndroid Build Coastguard Worker         for (auto& inputSlot : layer->GetInputSlots())
2427*89c4ff92SAndroid Build Coastguard Worker         {
2428*89c4ff92SAndroid Build Coastguard Worker             const Layer& connectedInputLayer = inputSlot.GetConnectedOutputSlot()->GetOwningLayer();
2429*89c4ff92SAndroid Build Coastguard Worker             const LayerType& owningLayerType = connectedInputLayer.GetType();
2430*89c4ff92SAndroid Build Coastguard Worker 
2431*89c4ff92SAndroid Build Coastguard Worker             if (owningLayerType == LayerType::Constant)
2432*89c4ff92SAndroid Build Coastguard Worker             {
2433*89c4ff92SAndroid Build Coastguard Worker                 continue;
2434*89c4ff92SAndroid Build Coastguard Worker             }
2435*89c4ff92SAndroid Build Coastguard Worker             if (inputImportingEnabled && owningLayerType == LayerType::Input)
2436*89c4ff92SAndroid Build Coastguard Worker             {
2437*89c4ff92SAndroid Build Coastguard Worker                 continue;
2438*89c4ff92SAndroid Build Coastguard Worker             }
2439*89c4ff92SAndroid Build Coastguard Worker             if (!m_SupportsExternallyManagedMemory[connectedInputLayer.GetBackendId()])
2440*89c4ff92SAndroid Build Coastguard Worker             {
2441*89c4ff92SAndroid Build Coastguard Worker                 continue;
2442*89c4ff92SAndroid Build Coastguard Worker             }
2443*89c4ff92SAndroid Build Coastguard Worker 
2444*89c4ff92SAndroid Build Coastguard Worker             auto outputSlot = inputSlot.GetConnectedOutputSlot();
2445*89c4ff92SAndroid Build Coastguard Worker 
2446*89c4ff92SAndroid Build Coastguard Worker             ITensorHandle* tensorHandle = outputSlot->GetOutputHandler().GetData();
2447*89c4ff92SAndroid Build Coastguard Worker             tensorHandle = TraceSubTensorHandleAncestry(tensorHandle);
2448*89c4ff92SAndroid Build Coastguard Worker 
2449*89c4ff92SAndroid Build Coastguard Worker             PartialBlock& partialBlock = memBlockTrackerMap.at(tensorHandle);
2450*89c4ff92SAndroid Build Coastguard Worker 
2451*89c4ff92SAndroid Build Coastguard Worker             auto& lifetime = partialBlock.m_Lifetime;
2452*89c4ff92SAndroid Build Coastguard Worker             --lifetime;
2453*89c4ff92SAndroid Build Coastguard Worker 
2454*89c4ff92SAndroid Build Coastguard Worker             if (lifetime == 0)
2455*89c4ff92SAndroid Build Coastguard Worker             {
2456*89c4ff92SAndroid Build Coastguard Worker                 m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2457*89c4ff92SAndroid Build Coastguard Worker                                                                      timestep,
2458*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_MemSize,
2459*89c4ff92SAndroid Build Coastguard Worker                                                                      0,
2460*89c4ff92SAndroid Build Coastguard Worker                                                                      partialBlock.m_Index);
2461*89c4ff92SAndroid Build Coastguard Worker             }
2462*89c4ff92SAndroid Build Coastguard Worker         }
2463*89c4ff92SAndroid Build Coastguard Worker         ++timestep;
2464*89c4ff92SAndroid Build Coastguard Worker     }
2465*89c4ff92SAndroid Build Coastguard Worker 
2466*89c4ff92SAndroid Build Coastguard Worker }
2467*89c4ff92SAndroid Build Coastguard Worker 
CreateExternalMemoryManger(std::vector<std::pair<std::shared_ptr<TensorMemory>,MemorySource>> & tensorMemoryVec)2468*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MemoryManager> LoadedNetwork::CreateExternalMemoryManger(
2469*89c4ff92SAndroid Build Coastguard Worker         std::vector<std::pair<std::shared_ptr<TensorMemory>, MemorySource>>& tensorMemoryVec)
2470*89c4ff92SAndroid Build Coastguard Worker {
2471*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<MemoryManager> memoryManager = std::make_unique<MemoryManager>();
2472*89c4ff92SAndroid Build Coastguard Worker     auto allocatorMap = BackendRegistryInstance().GetAllocators();
2473*89c4ff92SAndroid Build Coastguard Worker 
2474*89c4ff92SAndroid Build Coastguard Worker     for (auto& backend : m_MemBinMap)
2475*89c4ff92SAndroid Build Coastguard Worker     {
2476*89c4ff92SAndroid Build Coastguard Worker         std::vector<BufferStorage> bufferStorageVec;
2477*89c4ff92SAndroid Build Coastguard Worker 
2478*89c4ff92SAndroid Build Coastguard Worker         std::shared_ptr<ICustomAllocator> backendAllocator;
2479*89c4ff92SAndroid Build Coastguard Worker         if (allocatorMap.find(backend.first) != allocatorMap.end())
2480*89c4ff92SAndroid Build Coastguard Worker         {
2481*89c4ff92SAndroid Build Coastguard Worker             backendAllocator = allocatorMap[backend.first];
2482*89c4ff92SAndroid Build Coastguard Worker         }
2483*89c4ff92SAndroid Build Coastguard Worker         else
2484*89c4ff92SAndroid Build Coastguard Worker         {
2485*89c4ff92SAndroid Build Coastguard Worker             backendAllocator = m_Backends[backend.first]->GetDefaultAllocator();
2486*89c4ff92SAndroid Build Coastguard Worker         }
2487*89c4ff92SAndroid Build Coastguard Worker 
2488*89c4ff92SAndroid Build Coastguard Worker         for (auto& memBin : backend.second)
2489*89c4ff92SAndroid Build Coastguard Worker         {
2490*89c4ff92SAndroid Build Coastguard Worker             BufferStorage bufferStorage;
2491*89c4ff92SAndroid Build Coastguard Worker             bufferStorage.m_BufferSize = memBin.m_MemSize;
2492*89c4ff92SAndroid Build Coastguard Worker             bufferStorage.m_TensorMemoryVector.reserve(memBin.m_MemBlocks.size());
2493*89c4ff92SAndroid Build Coastguard Worker 
2494*89c4ff92SAndroid Build Coastguard Worker             for (auto& memBlock : memBin.m_MemBlocks)
2495*89c4ff92SAndroid Build Coastguard Worker             {
2496*89c4ff92SAndroid Build Coastguard Worker                 auto tensorMemory = std::make_shared<TensorMemory>(TensorMemory{memBlock.m_Offset, memBlock.m_Index});
2497*89c4ff92SAndroid Build Coastguard Worker 
2498*89c4ff92SAndroid Build Coastguard Worker                 tensorMemoryVec.emplace_back(tensorMemory, backendAllocator->GetMemorySourceType());
2499*89c4ff92SAndroid Build Coastguard Worker                 bufferStorage.m_TensorMemoryVector.emplace_back(tensorMemory);
2500*89c4ff92SAndroid Build Coastguard Worker             }
2501*89c4ff92SAndroid Build Coastguard Worker 
2502*89c4ff92SAndroid Build Coastguard Worker             bufferStorageVec.emplace_back(std::move(bufferStorage));
2503*89c4ff92SAndroid Build Coastguard Worker         }
2504*89c4ff92SAndroid Build Coastguard Worker 
2505*89c4ff92SAndroid Build Coastguard Worker         memoryManager->StoreMemToAllocate(bufferStorageVec, backendAllocator, 4);
2506*89c4ff92SAndroid Build Coastguard Worker     }
2507*89c4ff92SAndroid Build Coastguard Worker 
2508*89c4ff92SAndroid Build Coastguard Worker     return memoryManager;
2509*89c4ff92SAndroid Build Coastguard Worker }
2510*89c4ff92SAndroid Build Coastguard Worker 
ValidateImportedInputID(ImportedInputId id)2511*89c4ff92SAndroid Build Coastguard Worker LayerBindingId LoadedNetwork::ValidateImportedInputID(ImportedInputId id)
2512*89c4ff92SAndroid Build Coastguard Worker {
2513*89c4ff92SAndroid Build Coastguard Worker     try
2514*89c4ff92SAndroid Build Coastguard Worker     {
2515*89c4ff92SAndroid Build Coastguard Worker         const auto& importedTensorHandlePin = m_PreImportedInputHandles.at(id);
2516*89c4ff92SAndroid Build Coastguard Worker         if (!importedTensorHandlePin.m_TensorHandle)
2517*89c4ff92SAndroid Build Coastguard Worker         {
2518*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute:"
2519*89c4ff92SAndroid Build Coastguard Worker                                                        "PreImportedInput: {} has been deleted", id));
2520*89c4ff92SAndroid Build Coastguard Worker         }
2521*89c4ff92SAndroid Build Coastguard Worker         return importedTensorHandlePin.m_LayerBindingId;
2522*89c4ff92SAndroid Build Coastguard Worker     }
2523*89c4ff92SAndroid Build Coastguard Worker     catch (const std::out_of_range&)
2524*89c4ff92SAndroid Build Coastguard Worker     {
2525*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: Unknown ImportedInputId: {}", id));
2526*89c4ff92SAndroid Build Coastguard Worker     }
2527*89c4ff92SAndroid Build Coastguard Worker }
2528*89c4ff92SAndroid Build Coastguard Worker 
ValidateImportedOutputID(ImportedOutputId id)2529*89c4ff92SAndroid Build Coastguard Worker LayerBindingId LoadedNetwork::ValidateImportedOutputID(ImportedOutputId id)
2530*89c4ff92SAndroid Build Coastguard Worker {
2531*89c4ff92SAndroid Build Coastguard Worker     try
2532*89c4ff92SAndroid Build Coastguard Worker     {
2533*89c4ff92SAndroid Build Coastguard Worker         const auto& importedTensorHandlePin = m_PreImportedOutputHandles.at(id);
2534*89c4ff92SAndroid Build Coastguard Worker         if (!importedTensorHandlePin.m_TensorHandle)
2535*89c4ff92SAndroid Build Coastguard Worker         {
2536*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: "
2537*89c4ff92SAndroid Build Coastguard Worker                                                        "PreImportedOutput: {} has been deleted", id));
2538*89c4ff92SAndroid Build Coastguard Worker         }
2539*89c4ff92SAndroid Build Coastguard Worker         return importedTensorHandlePin.m_LayerBindingId;
2540*89c4ff92SAndroid Build Coastguard Worker     }
2541*89c4ff92SAndroid Build Coastguard Worker     catch (const std::out_of_range&)
2542*89c4ff92SAndroid Build Coastguard Worker     {
2543*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: Unknown ImportedOutputId: {}", id));
2544*89c4ff92SAndroid Build Coastguard Worker     }
2545*89c4ff92SAndroid Build Coastguard Worker }
2546*89c4ff92SAndroid Build Coastguard Worker 
2547*89c4ff92SAndroid Build Coastguard Worker }
2548