//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "LoadedNetwork.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "Layer.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "Graph.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "Profiling.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "HeapProfiling.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "WorkingMemHandle.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ExecutionData.hpp"
13*89c4ff92SAndroid Build Coastguard Worker
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/MemCopyWorkload.hpp>
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/profiling/ArmNNProfiling.hpp>
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/MemSyncWorkload.hpp>
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker #include <common/include/Processes.hpp>
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker namespace armnn
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker using namespace std;
37*89c4ff92SAndroid Build Coastguard Worker using namespace arm::pipe;
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker namespace
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker
/**
 * Builds an error string of the form "<prefix> <error.what()>".
 *
 * A single space is always inserted between the prefix and the exception text, even when the
 * prefix is empty or already ends in whitespace.
 *
 * @tparam ExceptionType Any type providing a what() member whose result can be streamed.
 * @param prefix Text placed in front of the exception message.
 * @param error  Exception whose what() text is appended.
 * @return The combined message.
 */
template <typename ExceptionType>
std::string ToErrorMessage(const char * prefix, const ExceptionType & error)
{
    // Only output is performed, so an output-only string stream is sufficient.
    std::ostringstream message;
    message << prefix << " " << error.what();
    return message.str();
}
49*89c4ff92SAndroid Build Coastguard Worker
AddLayerStructure(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,const Layer & layer,ProfilingGuid networkGuid)50*89c4ff92SAndroid Build Coastguard Worker void AddLayerStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
51*89c4ff92SAndroid Build Coastguard Worker const Layer& layer,
52*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid networkGuid)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker // Add layer to the post-optimisation network structure
55*89c4ff92SAndroid Build Coastguard Worker std::string layerName = layer.GetNameStr().empty() ? "<Unnamed>" : layer.GetNameStr();
56*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateNamedTypedChildEntity(layer.GetGuid(),
57*89c4ff92SAndroid Build Coastguard Worker networkGuid,
58*89c4ff92SAndroid Build Coastguard Worker layerName,
59*89c4ff92SAndroid Build Coastguard Worker LabelsAndEventClasses::LAYER_GUID);
60*89c4ff92SAndroid Build Coastguard Worker for (auto&& input : layer.GetInputSlots())
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker const IOutputSlot* source = input.GetConnectedOutputSlot();
63*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(source != NULL);
64*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateConnectionRelationship(ProfilingRelationshipType::RetentionLink,
65*89c4ff92SAndroid Build Coastguard Worker source->GetOwningLayerGuid(),
66*89c4ff92SAndroid Build Coastguard Worker layer.GetGuid());
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker }
69*89c4ff92SAndroid Build Coastguard Worker
AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,std::unique_ptr<IWorkload> & workload,const Layer & layer)70*89c4ff92SAndroid Build Coastguard Worker void AddWorkloadStructure(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
71*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload>& workload,
72*89c4ff92SAndroid Build Coastguard Worker const Layer& layer)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker // Add workload to the post-optimisation network structure
75*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateTypedEntity(workload->GetGuid(), LabelsAndEventClasses::WORKLOAD_GUID);
76*89c4ff92SAndroid Build Coastguard Worker timelineUtils->MarkEntityWithLabel(workload->GetGuid(),
77*89c4ff92SAndroid Build Coastguard Worker layer.GetBackendId().Get(),
78*89c4ff92SAndroid Build Coastguard Worker LabelsAndEventClasses::BACKENDID_GUID);
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker // Link the workload to the layer
81*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
82*89c4ff92SAndroid Build Coastguard Worker layer.GetGuid(),
83*89c4ff92SAndroid Build Coastguard Worker workload->GetGuid(),
84*89c4ff92SAndroid Build Coastguard Worker LabelsAndEventClasses::CHILD_GUID);
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker } // anonymous
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker /**
90*89c4ff92SAndroid Build Coastguard Worker * This function performs a sanity check to ensure that the combination of input and output memory source matches the
91*89c4ff92SAndroid Build Coastguard Worker * values for importEnabled and exportEnabled that were specified during optimization. During optimization the tensor
92*89c4ff92SAndroid Build Coastguard Worker * handle factories are chosen based on whether import and export are enabled. If the user then specifies something
93*89c4ff92SAndroid Build Coastguard Worker * incompatible here it can lead to problems.
94*89c4ff92SAndroid Build Coastguard Worker *
95*89c4ff92SAndroid Build Coastguard Worker * @param optimizedOptions
96*89c4ff92SAndroid Build Coastguard Worker * @param networkProperties
97*89c4ff92SAndroid Build Coastguard Worker */
ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,const INetworkProperties & networkProperties)98*89c4ff92SAndroid Build Coastguard Worker void ValidateSourcesMatchOptimizedNetwork(std::vector<BackendOptions> optimizedOptions,
99*89c4ff92SAndroid Build Coastguard Worker const INetworkProperties& networkProperties)
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker // Find the "Global" backend options. During the optimize phase the values of importEnabled and exportEnabled are
102*89c4ff92SAndroid Build Coastguard Worker // added as backend options.
103*89c4ff92SAndroid Build Coastguard Worker const vector<BackendOptions>::iterator& backendItr =
104*89c4ff92SAndroid Build Coastguard Worker find_if(optimizedOptions.begin(), optimizedOptions.end(), [](const BackendOptions& backend) {
105*89c4ff92SAndroid Build Coastguard Worker if (backend.GetBackendId().Get() == "Global")
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker return true;
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker else
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker return false;
112*89c4ff92SAndroid Build Coastguard Worker }
113*89c4ff92SAndroid Build Coastguard Worker });
114*89c4ff92SAndroid Build Coastguard Worker bool importEnabled = false;
115*89c4ff92SAndroid Build Coastguard Worker bool exportEnabled = false;
116*89c4ff92SAndroid Build Coastguard Worker if (backendItr != optimizedOptions.end())
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker // Find the importEnabled and exportEnabled values.
119*89c4ff92SAndroid Build Coastguard Worker for (size_t i = 0; i < backendItr->GetOptionCount(); i++)
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker const BackendOptions::BackendOption& option = backendItr->GetOption(i);
122*89c4ff92SAndroid Build Coastguard Worker if (option.GetName() == "ImportEnabled")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker importEnabled = option.GetValue().AsBool();
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker if (option.GetName() == "ExportEnabled")
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker exportEnabled = option.GetValue().AsBool();
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker
133*89c4ff92SAndroid Build Coastguard Worker // Now that we have values for import and export compare them to the MemorySource variables.
134*89c4ff92SAndroid Build Coastguard Worker // Any value of MemorySource that's not "Undefined" implies that we need to do an import of some kind.
135*89c4ff92SAndroid Build Coastguard Worker if ((networkProperties.m_InputSource == MemorySource::Undefined && importEnabled) ||
136*89c4ff92SAndroid Build Coastguard Worker (networkProperties.m_InputSource != MemorySource::Undefined && !importEnabled))
137*89c4ff92SAndroid Build Coastguard Worker {
138*89c4ff92SAndroid Build Coastguard Worker auto message = fmt::format("The input memory source specified, '{0}',", networkProperties.m_InputSource);
139*89c4ff92SAndroid Build Coastguard Worker if (!importEnabled)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker message.append(" requires that memory import be enabled. However, "
142*89c4ff92SAndroid Build Coastguard Worker "it was disabled when this network was optimized.");
143*89c4ff92SAndroid Build Coastguard Worker }
144*89c4ff92SAndroid Build Coastguard Worker else
145*89c4ff92SAndroid Build Coastguard Worker {
146*89c4ff92SAndroid Build Coastguard Worker message.append(" requires that memory import be disabled. However, "
147*89c4ff92SAndroid Build Coastguard Worker "it was enabled when this network was optimized.");
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(message);
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker
152*89c4ff92SAndroid Build Coastguard Worker if ((networkProperties.m_OutputSource == MemorySource::Undefined && exportEnabled) ||
153*89c4ff92SAndroid Build Coastguard Worker (networkProperties.m_OutputSource != MemorySource::Undefined && !exportEnabled))
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker auto message = fmt::format("The output memory source specified, '{0}',", networkProperties.m_OutputSource);
156*89c4ff92SAndroid Build Coastguard Worker if (!exportEnabled)
157*89c4ff92SAndroid Build Coastguard Worker {
158*89c4ff92SAndroid Build Coastguard Worker message.append(" requires that memory export be enabled. However, "
159*89c4ff92SAndroid Build Coastguard Worker "it was disabled when this network was optimized.");
160*89c4ff92SAndroid Build Coastguard Worker }
161*89c4ff92SAndroid Build Coastguard Worker else
162*89c4ff92SAndroid Build Coastguard Worker {
163*89c4ff92SAndroid Build Coastguard Worker message.append(" requires that memory export be disabled. However, "
164*89c4ff92SAndroid Build Coastguard Worker "it was enabled when this network was optimized.");
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(message);
167*89c4ff92SAndroid Build Coastguard Worker }
168*89c4ff92SAndroid Build Coastguard Worker } // anonymous
169*89c4ff92SAndroid Build Coastguard Worker
MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,std::string & errorMessage,const INetworkProperties & networkProperties,arm::pipe::IProfilingService * profilingService)170*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<LoadedNetwork> LoadedNetwork::MakeLoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
171*89c4ff92SAndroid Build Coastguard Worker std::string& errorMessage,
172*89c4ff92SAndroid Build Coastguard Worker const INetworkProperties& networkProperties,
173*89c4ff92SAndroid Build Coastguard Worker arm::pipe::IProfilingService* profilingService)
174*89c4ff92SAndroid Build Coastguard Worker {
175*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<LoadedNetwork> loadedNetwork;
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker auto Fail = [&](const std::exception& error) -> std::unique_ptr<LoadedNetwork>
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker errorMessage = ToErrorMessage("An error occurred when preparing the network workloads: ", error);
180*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << errorMessage;
181*89c4ff92SAndroid Build Coastguard Worker
182*89c4ff92SAndroid Build Coastguard Worker return std::unique_ptr<LoadedNetwork>();
183*89c4ff92SAndroid Build Coastguard Worker };
184*89c4ff92SAndroid Build Coastguard Worker
185*89c4ff92SAndroid Build Coastguard Worker try
186*89c4ff92SAndroid Build Coastguard Worker {
187*89c4ff92SAndroid Build Coastguard Worker loadedNetwork.reset(new LoadedNetwork(std::move(net), networkProperties, profilingService));
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker catch (const armnn::RuntimeException& error)
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker return Fail(error);
192*89c4ff92SAndroid Build Coastguard Worker }
193*89c4ff92SAndroid Build Coastguard Worker catch (const armnn::Exception& error)
194*89c4ff92SAndroid Build Coastguard Worker {
195*89c4ff92SAndroid Build Coastguard Worker return Fail(error);
196*89c4ff92SAndroid Build Coastguard Worker }
197*89c4ff92SAndroid Build Coastguard Worker catch (const std::runtime_error& error)
198*89c4ff92SAndroid Build Coastguard Worker {
199*89c4ff92SAndroid Build Coastguard Worker return Fail(error);
200*89c4ff92SAndroid Build Coastguard Worker }
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker return loadedNetwork;
203*89c4ff92SAndroid Build Coastguard Worker }
204*89c4ff92SAndroid Build Coastguard Worker
LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,const INetworkProperties & networkProperties,arm::pipe::IProfilingService * profilingService)205*89c4ff92SAndroid Build Coastguard Worker LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
206*89c4ff92SAndroid Build Coastguard Worker const INetworkProperties& networkProperties,
207*89c4ff92SAndroid Build Coastguard Worker arm::pipe::IProfilingService* profilingService) :
208*89c4ff92SAndroid Build Coastguard Worker m_OptimizedNetwork(std::move(net)),
209*89c4ff92SAndroid Build Coastguard Worker m_NetworkProperties(networkProperties),
210*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry(),
211*89c4ff92SAndroid Build Coastguard Worker m_ProfilingService(profilingService)
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadedNetwork");
214*89c4ff92SAndroid Build Coastguard Worker // Get the profiler and register it for the current thread.
215*89c4ff92SAndroid Build Coastguard Worker const std::shared_ptr<IProfiler>& profiler = m_OptimizedNetwork->GetProfiler();
216*89c4ff92SAndroid Build Coastguard Worker ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
217*89c4ff92SAndroid Build Coastguard Worker
218*89c4ff92SAndroid Build Coastguard Worker profiler->EnableProfiling(networkProperties.m_ProfilingEnabled);
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker profiler->EnableNetworkDetailsToStdOut(networkProperties.m_OutputNetworkDetailsMethod);
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker // We need to check that the memory sources match up with the values of import and export specified during the
223*89c4ff92SAndroid Build Coastguard Worker // optimize phase. If they don't this will throw an exception.
224*89c4ff92SAndroid Build Coastguard Worker ValidateSourcesMatchOptimizedNetwork(m_OptimizedNetwork.get()->pOptimizedNetworkImpl->GetModelOptions(),
225*89c4ff92SAndroid Build Coastguard Worker m_NetworkProperties);
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker //First create tensor handlers, backends and workload factories.
228*89c4ff92SAndroid Build Coastguard Worker //Handlers are created before workloads are.
229*89c4ff92SAndroid Build Coastguard Worker //Because workload creation can modify some of the handlers,
230*89c4ff92SAndroid Build Coastguard Worker //(for example the splitter and concat layers).
231*89c4ff92SAndroid Build Coastguard Worker
232*89c4ff92SAndroid Build Coastguard Worker bool useExternalMemoryManager = false;
233*89c4ff92SAndroid Build Coastguard Worker bool useInternalMemoryManager = false;
234*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
235*89c4ff92SAndroid Build Coastguard Worker // Ensure Topological order
236*89c4ff92SAndroid Build Coastguard Worker order.SetLayersOutOfOrder();
237*89c4ff92SAndroid Build Coastguard Worker order.TopologicalSort();
238*89c4ff92SAndroid Build Coastguard Worker
239*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker m_IsInputImported = std::vector<bool>(order.GetNumInputs(), false);
242*89c4ff92SAndroid Build Coastguard Worker m_IsOutputImported = std::vector<bool>(order.GetNumOutputs(), false);
243*89c4ff92SAndroid Build Coastguard Worker }
244*89c4ff92SAndroid Build Coastguard Worker
245*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
246*89c4ff92SAndroid Build Coastguard Worker {
247*89c4ff92SAndroid Build Coastguard Worker auto const& backendId = layer->GetBackendId();
248*89c4ff92SAndroid Build Coastguard Worker if (m_Backends.count(backendId) == 0)
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker auto createBackend = BackendRegistryInstance().GetFactory(backendId);
251*89c4ff92SAndroid Build Coastguard Worker auto it = m_Backends.emplace(std::make_pair(backendId, createBackend()));
252*89c4ff92SAndroid Build Coastguard Worker
253*89c4ff92SAndroid Build Coastguard Worker IBackendInternal* backend = it.first->second.get();
254*89c4ff92SAndroid Build Coastguard Worker
255*89c4ff92SAndroid Build Coastguard Worker // If we're doing async execution verify that the backend supports it and ExternallyManagedMemory.
256*89c4ff92SAndroid Build Coastguard Worker if (networkProperties.m_AsyncEnabled)
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker if (!HasCapability(BackendOptions::BackendOption{"AsyncExecution", true}, backend->GetCapabilities()))
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker std::string er = backend->GetId();
261*89c4ff92SAndroid Build Coastguard Worker er += " does not support AsyncExecution";
262*89c4ff92SAndroid Build Coastguard Worker throw BackendCapabilityException(er);
263*89c4ff92SAndroid Build Coastguard Worker }
264*89c4ff92SAndroid Build Coastguard Worker if (!HasCapability(BackendOptions::BackendOption{"ExternallyManagedMemory", true},
265*89c4ff92SAndroid Build Coastguard Worker backend->GetCapabilities()))
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker std::string er = backend->GetId();
268*89c4ff92SAndroid Build Coastguard Worker er += " does not support ExternallyManagedMemory\n";
269*89c4ff92SAndroid Build Coastguard Worker er += "AsyncEnabled networks require all backends to support ExternallyManagedMemory";
270*89c4ff92SAndroid Build Coastguard Worker throw BackendCapabilityException(er);
271*89c4ff92SAndroid Build Coastguard Worker }
272*89c4ff92SAndroid Build Coastguard Worker m_SupportsExternallyManagedMemory[backend->GetId()] = true;
273*89c4ff92SAndroid Build Coastguard Worker useExternalMemoryManager = true;
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker else
276*89c4ff92SAndroid Build Coastguard Worker {
277*89c4ff92SAndroid Build Coastguard Worker m_SupportsExternallyManagedMemory[backend->GetId()] = false;
278*89c4ff92SAndroid Build Coastguard Worker useInternalMemoryManager = true;
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker
281*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr workloadFactory;
282*89c4ff92SAndroid Build Coastguard Worker if (backend->SupportsTensorAllocatorAPI())
283*89c4ff92SAndroid Build Coastguard Worker {
284*89c4ff92SAndroid Build Coastguard Worker workloadFactory = backend->CreateWorkloadFactory(
285*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry,
286*89c4ff92SAndroid Build Coastguard Worker m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions(),
287*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(m_NetworkProperties.m_InputSource),
288*89c4ff92SAndroid Build Coastguard Worker static_cast<MemorySourceFlags>(m_NetworkProperties.m_OutputSource));
289*89c4ff92SAndroid Build Coastguard Worker }
290*89c4ff92SAndroid Build Coastguard Worker else
291*89c4ff92SAndroid Build Coastguard Worker {
292*89c4ff92SAndroid Build Coastguard Worker m_BackendMemoryMangers.emplace_back(backend->CreateMemoryManager());
293*89c4ff92SAndroid Build Coastguard Worker workloadFactory = backend->CreateWorkloadFactory(
294*89c4ff92SAndroid Build Coastguard Worker m_BackendMemoryMangers.back(), m_OptimizedNetwork->pOptimizedNetworkImpl->GetModelOptions());
295*89c4ff92SAndroid Build Coastguard Worker }
296*89c4ff92SAndroid Build Coastguard Worker m_WorkloadFactories[backendId ] = std::move(workloadFactory);
297*89c4ff92SAndroid Build Coastguard Worker }
298*89c4ff92SAndroid Build Coastguard Worker }
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
301*89c4ff92SAndroid Build Coastguard Worker {
302*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
303*89c4ff92SAndroid Build Coastguard Worker {
304*89c4ff92SAndroid Build Coastguard Worker auto& workloadFactory = GetWorkloadFactory(*layer);
305*89c4ff92SAndroid Build Coastguard Worker bool supportsExternalManager = m_SupportsExternallyManagedMemory[layer->GetBackendId()];
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType())
308*89c4ff92SAndroid Build Coastguard Worker {
309*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
310*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemImport:
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker // If IsImportEnabled is true then we need to set IsMemoryManaged
313*89c4ff92SAndroid Build Coastguard Worker // to false when creating TensorHandles
314*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
315*89c4ff92SAndroid Build Coastguard Worker workloadFactory,
316*89c4ff92SAndroid Build Coastguard Worker !supportsExternalManager && !m_NetworkProperties.m_ImportEnabled);
317*89c4ff92SAndroid Build Coastguard Worker break;
318*89c4ff92SAndroid Build Coastguard Worker }
319*89c4ff92SAndroid Build Coastguard Worker case LayerType::Constant:
320*89c4ff92SAndroid Build Coastguard Worker {
321*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory, true);
322*89c4ff92SAndroid Build Coastguard Worker break;
323*89c4ff92SAndroid Build Coastguard Worker }
324*89c4ff92SAndroid Build Coastguard Worker default:
325*89c4ff92SAndroid Build Coastguard Worker {
326*89c4ff92SAndroid Build Coastguard Worker // Look for a layer with 1 OutputSlot which has 1 connection and that connection is an Output Layer
327*89c4ff92SAndroid Build Coastguard Worker // If Export is enabled disable memory management so we can export, otherwise we do a copy
328*89c4ff92SAndroid Build Coastguard Worker if ((layer->GetNumOutputSlots() == 1) &&
329*89c4ff92SAndroid Build Coastguard Worker (layer->GetOutputSlots()[0].GetNumConnections() == 1) &&
330*89c4ff92SAndroid Build Coastguard Worker (layer->GetOutputSlots()[0].GetConnection(0)->GetOwningLayer().GetType() == LayerType::Output))
331*89c4ff92SAndroid Build Coastguard Worker {
332*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
333*89c4ff92SAndroid Build Coastguard Worker workloadFactory,
334*89c4ff92SAndroid Build Coastguard Worker !supportsExternalManager && !m_NetworkProperties.m_ExportEnabled);
335*89c4ff92SAndroid Build Coastguard Worker }
336*89c4ff92SAndroid Build Coastguard Worker else
337*89c4ff92SAndroid Build Coastguard Worker {
338*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(m_TensorHandleFactoryRegistry,
339*89c4ff92SAndroid Build Coastguard Worker workloadFactory,
340*89c4ff92SAndroid Build Coastguard Worker !supportsExternalManager);
341*89c4ff92SAndroid Build Coastguard Worker }
342*89c4ff92SAndroid Build Coastguard Worker }
343*89c4ff92SAndroid Build Coastguard Worker }
344*89c4ff92SAndroid Build Coastguard Worker }
345*89c4ff92SAndroid Build Coastguard Worker }
346*89c4ff92SAndroid Build Coastguard Worker
347*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
348*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
349*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
350*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
351*89c4ff92SAndroid Build Coastguard Worker {
352*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID);
353*89c4ff92SAndroid Build Coastguard Worker // Mark the network with a start of life event
354*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEvent(networkGuid, LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
355*89c4ff92SAndroid Build Coastguard Worker // and with the process ID
356*89c4ff92SAndroid Build Coastguard Worker int processID = arm::pipe::GetCurrentProcessId();
357*89c4ff92SAndroid Build Coastguard Worker std::stringstream ss;
358*89c4ff92SAndroid Build Coastguard Worker ss << processID;
359*89c4ff92SAndroid Build Coastguard Worker timelineUtils->MarkEntityWithLabel(networkGuid, ss.str(), LabelsAndEventClasses::PROCESS_ID_GUID);
360*89c4ff92SAndroid Build Coastguard Worker }
361*89c4ff92SAndroid Build Coastguard Worker
362*89c4ff92SAndroid Build Coastguard Worker std::vector<IWorkload*> ConstWorkloads;
363*89c4ff92SAndroid Build Coastguard Worker
364*89c4ff92SAndroid Build Coastguard Worker //Then create workloads.
365*89c4ff92SAndroid Build Coastguard Worker {
366*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_CreateWorkloads");
367*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer: order)
368*89c4ff92SAndroid Build Coastguard Worker {
369*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
370*89c4ff92SAndroid Build Coastguard Worker {
371*89c4ff92SAndroid Build Coastguard Worker // Add layer to the post-optimisation network structure
372*89c4ff92SAndroid Build Coastguard Worker AddLayerStructure(timelineUtils, *layer, networkGuid);
373*89c4ff92SAndroid Build Coastguard Worker }
374*89c4ff92SAndroid Build Coastguard Worker
375*89c4ff92SAndroid Build Coastguard Worker const IWorkloadFactory& workloadFactory = GetWorkloadFactory(*layer);
376*89c4ff92SAndroid Build Coastguard Worker
377*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType())
378*89c4ff92SAndroid Build Coastguard Worker {
379*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
380*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
381*89c4ff92SAndroid Build Coastguard Worker {
382*89c4ff92SAndroid Build Coastguard Worker // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput().
383*89c4ff92SAndroid Build Coastguard Worker break;
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker default:
386*89c4ff92SAndroid Build Coastguard Worker {
387*89c4ff92SAndroid Build Coastguard Worker auto workload = layer->CreateWorkload(workloadFactory);
388*89c4ff92SAndroid Build Coastguard Worker
389*89c4ff92SAndroid Build Coastguard Worker if (!workload)
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker const char* const layerName =
392*89c4ff92SAndroid Build Coastguard Worker layer->GetNameStr().length() != 0 ? layer->GetName() : "<Unnamed>";
393*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
394*89c4ff92SAndroid Build Coastguard Worker fmt::format("No workload created for layer (name: '{0}' type: '{1}') (compute '{2}')",
395*89c4ff92SAndroid Build Coastguard Worker layerName, static_cast<int>(layer->GetType()), layer->GetBackendId().Get()
396*89c4ff92SAndroid Build Coastguard Worker ));
397*89c4ff92SAndroid Build Coastguard Worker }
398*89c4ff92SAndroid Build Coastguard Worker
399*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
400*89c4ff92SAndroid Build Coastguard Worker {
401*89c4ff92SAndroid Build Coastguard Worker // Add workload to the post-optimisation network structure
402*89c4ff92SAndroid Build Coastguard Worker AddWorkloadStructure(timelineUtils, workload, *layer);
403*89c4ff92SAndroid Build Coastguard Worker }
404*89c4ff92SAndroid Build Coastguard Worker
405*89c4ff92SAndroid Build Coastguard Worker // For async networks ConstantWorkloads are managed exclusively by LoadedNetwork
406*89c4ff92SAndroid Build Coastguard Worker // and are separated out from the other workloads
407*89c4ff92SAndroid Build Coastguard Worker if((networkProperties.m_AsyncEnabled || useExternalMemoryManager) &&
408*89c4ff92SAndroid Build Coastguard Worker layer->GetType() == LayerType::Constant)
409*89c4ff92SAndroid Build Coastguard Worker {
410*89c4ff92SAndroid Build Coastguard Worker m_ConstantTensorHandles[layer->GetGuid()] =
411*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).GetOutputHandler().GetData();
412*89c4ff92SAndroid Build Coastguard Worker m_ConstantWorkloads[layer->GetGuid()] = std::move(workload);
413*89c4ff92SAndroid Build Coastguard Worker }
414*89c4ff92SAndroid Build Coastguard Worker else
415*89c4ff92SAndroid Build Coastguard Worker {
416*89c4ff92SAndroid Build Coastguard Worker m_WorkloadQueue.push_back(std::move(workload));
417*89c4ff92SAndroid Build Coastguard Worker
418*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == LayerType::Constant)
419*89c4ff92SAndroid Build Coastguard Worker {
420*89c4ff92SAndroid Build Coastguard Worker // Place the Constant Workloads into a queue so that they can be executed first
421*89c4ff92SAndroid Build Coastguard Worker ConstWorkloads.push_back(m_WorkloadQueue.back().get());
422*89c4ff92SAndroid Build Coastguard Worker }
423*89c4ff92SAndroid Build Coastguard Worker }
424*89c4ff92SAndroid Build Coastguard Worker // release the constant data in the layer.
425*89c4ff92SAndroid Build Coastguard Worker layer->ReleaseConstantData();
426*89c4ff92SAndroid Build Coastguard Worker break;
427*89c4ff92SAndroid Build Coastguard Worker }
428*89c4ff92SAndroid Build Coastguard Worker }
429*89c4ff92SAndroid Build Coastguard Worker }
430*89c4ff92SAndroid Build Coastguard Worker }
431*89c4ff92SAndroid Build Coastguard Worker
432*89c4ff92SAndroid Build Coastguard Worker // Gather information about workloads for inputs & outputs
433*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled && m_WorkloadQueue.size() != 0)
434*89c4ff92SAndroid Build Coastguard Worker {
435*89c4ff92SAndroid Build Coastguard Worker const int noOfInputs = armnn::numeric_cast<int>(order.GetNumInputs());
436*89c4ff92SAndroid Build Coastguard Worker
437*89c4ff92SAndroid Build Coastguard Worker // Get indices of all workloads connected to each input and
438*89c4ff92SAndroid Build Coastguard Worker // check if they support tensor handle replacement
439*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* layer: order.GetInputLayers())
440*89c4ff92SAndroid Build Coastguard Worker {
441*89c4ff92SAndroid Build Coastguard Worker const auto bindingId = layer->GetBindingId();
442*89c4ff92SAndroid Build Coastguard Worker
443*89c4ff92SAndroid Build Coastguard Worker bool supportsReplacement = true;
444*89c4ff92SAndroid Build Coastguard Worker
445*89c4ff92SAndroid Build Coastguard Worker for (const auto inputSlot: layer->GetOutputSlot(0).GetConnections())
446*89c4ff92SAndroid Build Coastguard Worker {
447*89c4ff92SAndroid Build Coastguard Worker auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(inputSlot->GetOwningLayer()));
448*89c4ff92SAndroid Build Coastguard Worker workloadIndex -= noOfInputs;
449*89c4ff92SAndroid Build Coastguard Worker
450*89c4ff92SAndroid Build Coastguard Worker m_InputWorkloadSlotPairs[bindingId].emplace_back(WorkloadIndices{
451*89c4ff92SAndroid Build Coastguard Worker armnn::numeric_cast<unsigned int>(workloadIndex), inputSlot->GetSlotIndex()});
452*89c4ff92SAndroid Build Coastguard Worker
453*89c4ff92SAndroid Build Coastguard Worker auto workload = m_WorkloadQueue[m_InputWorkloadSlotPairs[bindingId].back().m_WorkloadIndex].get();
454*89c4ff92SAndroid Build Coastguard Worker supportsReplacement &= workload->SupportsTensorHandleReplacement();
455*89c4ff92SAndroid Build Coastguard Worker }
456*89c4ff92SAndroid Build Coastguard Worker
457*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId factoryId = layer->GetOutputSlot(0).GetTensorHandleFactoryId();
458*89c4ff92SAndroid Build Coastguard Worker // Get matching import factory Id
459*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId importFactoryId =
460*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
461*89c4ff92SAndroid Build Coastguard Worker
462*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
463*89c4ff92SAndroid Build Coastguard Worker
464*89c4ff92SAndroid Build Coastguard Worker if (supportsReplacement && importFactory)
465*89c4ff92SAndroid Build Coastguard Worker {
466*89c4ff92SAndroid Build Coastguard Worker m_PreImportedInputHandles.emplace_back(
467*89c4ff92SAndroid Build Coastguard Worker bindingId, importFactory->CreateTensorHandle(layer->GetOutputSlot(0).GetTensorInfo(), false));
468*89c4ff92SAndroid Build Coastguard Worker }
469*89c4ff92SAndroid Build Coastguard Worker else
470*89c4ff92SAndroid Build Coastguard Worker {
471*89c4ff92SAndroid Build Coastguard Worker m_PreImportedInputHandles.emplace_back(bindingId, nullptr);
472*89c4ff92SAndroid Build Coastguard Worker }
473*89c4ff92SAndroid Build Coastguard Worker }
474*89c4ff92SAndroid Build Coastguard Worker
475*89c4ff92SAndroid Build Coastguard Worker // Get indices of all workloads connected to each output and
476*89c4ff92SAndroid Build Coastguard Worker // check if they support tensor handle replacement
477*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* layer: order.GetOutputLayers())
478*89c4ff92SAndroid Build Coastguard Worker {
479*89c4ff92SAndroid Build Coastguard Worker const auto bindingId = layer->GetBindingId();
480*89c4ff92SAndroid Build Coastguard Worker
481*89c4ff92SAndroid Build Coastguard Worker const auto outputSlot = layer->GetInputSlot(0).GetConnectedOutputSlot();
482*89c4ff92SAndroid Build Coastguard Worker auto& indices = m_OutputWorkloadSlotPairs[bindingId];
483*89c4ff92SAndroid Build Coastguard Worker
484*89c4ff92SAndroid Build Coastguard Worker auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(outputSlot->GetOwningLayer()));
485*89c4ff92SAndroid Build Coastguard Worker workloadIndex -= noOfInputs;
486*89c4ff92SAndroid Build Coastguard Worker
487*89c4ff92SAndroid Build Coastguard Worker indices.m_OutputSlotIndices = WorkloadIndices{numeric_cast<unsigned int>(workloadIndex),
488*89c4ff92SAndroid Build Coastguard Worker outputSlot->CalculateIndexOnOwner()};
489*89c4ff92SAndroid Build Coastguard Worker
490*89c4ff92SAndroid Build Coastguard Worker bool supportsReplacement = true;
491*89c4ff92SAndroid Build Coastguard Worker auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
492*89c4ff92SAndroid Build Coastguard Worker supportsReplacement &= outputWorkload->SupportsTensorHandleReplacement();
493*89c4ff92SAndroid Build Coastguard Worker
494*89c4ff92SAndroid Build Coastguard Worker for (auto &inputSlot: outputSlot->GetConnections())
495*89c4ff92SAndroid Build Coastguard Worker {
496*89c4ff92SAndroid Build Coastguard Worker if(inputSlot->GetOwningLayer().GetType() != LayerType::Output)
497*89c4ff92SAndroid Build Coastguard Worker {
498*89c4ff92SAndroid Build Coastguard Worker auto inWorkloadIndex = std::distance(order.begin(),
499*89c4ff92SAndroid Build Coastguard Worker order.GetPosInGraph(inputSlot->GetOwningLayer()));
500*89c4ff92SAndroid Build Coastguard Worker inWorkloadIndex -= noOfInputs;
501*89c4ff92SAndroid Build Coastguard Worker indices.m_InputSlotIndices.emplace_back(WorkloadIndices{numeric_cast<unsigned int>(inWorkloadIndex),
502*89c4ff92SAndroid Build Coastguard Worker inputSlot->GetSlotIndex()});
503*89c4ff92SAndroid Build Coastguard Worker auto inputWorkload = m_WorkloadQueue[indices.m_InputSlotIndices.back().m_WorkloadIndex].get();
504*89c4ff92SAndroid Build Coastguard Worker supportsReplacement &= inputWorkload->SupportsTensorHandleReplacement();
505*89c4ff92SAndroid Build Coastguard Worker }
506*89c4ff92SAndroid Build Coastguard Worker }
507*89c4ff92SAndroid Build Coastguard Worker
508*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId();
509*89c4ff92SAndroid Build Coastguard Worker // Get matching import factory Id
510*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId importFactoryId =
511*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
512*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
513*89c4ff92SAndroid Build Coastguard Worker
514*89c4ff92SAndroid Build Coastguard Worker if (supportsReplacement && importFactory)
515*89c4ff92SAndroid Build Coastguard Worker {
516*89c4ff92SAndroid Build Coastguard Worker m_PreImportedOutputHandles.emplace_back(
517*89c4ff92SAndroid Build Coastguard Worker bindingId, importFactory->CreateTensorHandle(outputSlot->GetTensorInfo(), false));
518*89c4ff92SAndroid Build Coastguard Worker }
519*89c4ff92SAndroid Build Coastguard Worker else
520*89c4ff92SAndroid Build Coastguard Worker {
521*89c4ff92SAndroid Build Coastguard Worker m_PreImportedOutputHandles.emplace_back(bindingId, nullptr);
522*89c4ff92SAndroid Build Coastguard Worker }
523*89c4ff92SAndroid Build Coastguard Worker }
524*89c4ff92SAndroid Build Coastguard Worker }
525*89c4ff92SAndroid Build Coastguard Worker
526*89c4ff92SAndroid Build Coastguard Worker for (auto&& workloadFactory : m_WorkloadFactories)
527*89c4ff92SAndroid Build Coastguard Worker {
528*89c4ff92SAndroid Build Coastguard Worker workloadFactory.second->AfterWorkloadsCreated();
529*89c4ff92SAndroid Build Coastguard Worker }
530*89c4ff92SAndroid Build Coastguard Worker
531*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
532*89c4ff92SAndroid Build Coastguard Worker {
533*89c4ff92SAndroid Build Coastguard Worker // Commit to send the post-optimisation network structure
534*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
535*89c4ff92SAndroid Build Coastguard Worker }
536*89c4ff92SAndroid Build Coastguard Worker
537*89c4ff92SAndroid Build Coastguard Worker if (useExternalMemoryManager)
538*89c4ff92SAndroid Build Coastguard Worker {
539*89c4ff92SAndroid Build Coastguard Worker if (networkProperties.m_AsyncEnabled)
540*89c4ff92SAndroid Build Coastguard Worker {
541*89c4ff92SAndroid Build Coastguard Worker CreateMemoryProfileAsync();
542*89c4ff92SAndroid Build Coastguard Worker }
543*89c4ff92SAndroid Build Coastguard Worker else
544*89c4ff92SAndroid Build Coastguard Worker {
545*89c4ff92SAndroid Build Coastguard Worker CreateMemoryProfile();
546*89c4ff92SAndroid Build Coastguard Worker }
547*89c4ff92SAndroid Build Coastguard Worker
548*89c4ff92SAndroid Build Coastguard Worker auto backendStrategyMap = BackendRegistryInstance().GetMemoryOptimizerStrategies();
549*89c4ff92SAndroid Build Coastguard Worker for (auto& backendMemoryProfile : m_MemBlockMap)
550*89c4ff92SAndroid Build Coastguard Worker {
551*89c4ff92SAndroid Build Coastguard Worker const BackendId& backendId = backendMemoryProfile.first;
552*89c4ff92SAndroid Build Coastguard Worker if (backendStrategyMap.find(backendId) != backendStrategyMap.end())
553*89c4ff92SAndroid Build Coastguard Worker {
554*89c4ff92SAndroid Build Coastguard Worker m_MemBinMap[backendId] = backendStrategyMap[backendId]->Optimize(backendMemoryProfile.second);
555*89c4ff92SAndroid Build Coastguard Worker }
556*89c4ff92SAndroid Build Coastguard Worker else
557*89c4ff92SAndroid Build Coastguard Worker {
558*89c4ff92SAndroid Build Coastguard Worker m_MemBinMap[backendId] = m_ConstantStrategy->Optimize(backendMemoryProfile.second);
559*89c4ff92SAndroid Build Coastguard Worker }
560*89c4ff92SAndroid Build Coastguard Worker }
561*89c4ff92SAndroid Build Coastguard Worker
562*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
563*89c4ff92SAndroid Build Coastguard Worker {
564*89c4ff92SAndroid Build Coastguard Worker m_ExternalMemoryManager = CreateExternalMemoryManger(m_TensorMemory);
565*89c4ff92SAndroid Build Coastguard Worker
566*89c4ff92SAndroid Build Coastguard Worker // Sort m_TensorMemory so that its order matches m_Tensorhandles
567*89c4ff92SAndroid Build Coastguard Worker std::sort(m_TensorMemory.begin(), m_TensorMemory.end(),
568*89c4ff92SAndroid Build Coastguard Worker [](const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& lhs,
569*89c4ff92SAndroid Build Coastguard Worker const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& rhs)
570*89c4ff92SAndroid Build Coastguard Worker {
571*89c4ff92SAndroid Build Coastguard Worker return lhs.first->m_OutputSlotId < rhs.first->m_OutputSlotId;
572*89c4ff92SAndroid Build Coastguard Worker });
573*89c4ff92SAndroid Build Coastguard Worker }
574*89c4ff92SAndroid Build Coastguard Worker }
575*89c4ff92SAndroid Build Coastguard Worker
576*89c4ff92SAndroid Build Coastguard Worker // Now that the intermediate tensor memory has been set-up,
577*89c4ff92SAndroid Build Coastguard Worker // do any post allocation configuration for each workload.
578*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
579*89c4ff92SAndroid Build Coastguard Worker {
580*89c4ff92SAndroid Build Coastguard Worker if (useInternalMemoryManager)
581*89c4ff92SAndroid Build Coastguard Worker {
582*89c4ff92SAndroid Build Coastguard Worker // Set up memory.
583*89c4ff92SAndroid Build Coastguard Worker m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().AllocateDynamicBuffers();
584*89c4ff92SAndroid Build Coastguard Worker }
585*89c4ff92SAndroid Build Coastguard Worker
586*89c4ff92SAndroid Build Coastguard Worker for (auto &workload : m_WorkloadQueue)
587*89c4ff92SAndroid Build Coastguard Worker {
588*89c4ff92SAndroid Build Coastguard Worker workload->PostAllocationConfigure();
589*89c4ff92SAndroid Build Coastguard Worker }
590*89c4ff92SAndroid Build Coastguard Worker }
591*89c4ff92SAndroid Build Coastguard Worker
592*89c4ff92SAndroid Build Coastguard Worker if (useExternalMemoryManager)
593*89c4ff92SAndroid Build Coastguard Worker {
594*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
595*89c4ff92SAndroid Build Coastguard Worker {
596*89c4ff92SAndroid Build Coastguard Worker AllocateAndExecuteConstantWorkloads();
597*89c4ff92SAndroid Build Coastguard Worker }
598*89c4ff92SAndroid Build Coastguard Worker else
599*89c4ff92SAndroid Build Coastguard Worker {
600*89c4ff92SAndroid Build Coastguard Worker AllocateAndExecuteConstantWorkloadsAsync();
601*89c4ff92SAndroid Build Coastguard Worker }
602*89c4ff92SAndroid Build Coastguard Worker }
603*89c4ff92SAndroid Build Coastguard Worker // If synchronous, execute all constant layer workloads
604*89c4ff92SAndroid Build Coastguard Worker if (!networkProperties.m_AsyncEnabled)
605*89c4ff92SAndroid Build Coastguard Worker {
606*89c4ff92SAndroid Build Coastguard Worker for (auto workload: ConstWorkloads)
607*89c4ff92SAndroid Build Coastguard Worker {
608*89c4ff92SAndroid Build Coastguard Worker workload->Execute();
609*89c4ff92SAndroid Build Coastguard Worker }
610*89c4ff92SAndroid Build Coastguard Worker }
611*89c4ff92SAndroid Build Coastguard Worker }
612*89c4ff92SAndroid Build Coastguard Worker
AllocateAndExecuteConstantWorkloads()613*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateAndExecuteConstantWorkloads()
614*89c4ff92SAndroid Build Coastguard Worker {
615*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_AllocateAndExecuteConstants");
616*89c4ff92SAndroid Build Coastguard Worker for (auto& pair : m_ConstantWorkloads)
617*89c4ff92SAndroid Build Coastguard Worker {
618*89c4ff92SAndroid Build Coastguard Worker auto tensorHandle = m_ConstantTensorHandles[pair.first];
619*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Allocate();
620*89c4ff92SAndroid Build Coastguard Worker pair.second->Execute();
621*89c4ff92SAndroid Build Coastguard Worker }
622*89c4ff92SAndroid Build Coastguard Worker }
623*89c4ff92SAndroid Build Coastguard Worker
AllocateAndExecuteConstantWorkloadsAsync()624*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateAndExecuteConstantWorkloadsAsync()
625*89c4ff92SAndroid Build Coastguard Worker {
626*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_AllocateAndExecuteConstants");
627*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
628*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
629*89c4ff92SAndroid Build Coastguard Worker {
630*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == LayerType::Constant)
631*89c4ff92SAndroid Build Coastguard Worker {
632*89c4ff92SAndroid Build Coastguard Worker const auto& outSlot = layer->GetOutputSlots()[0];
633*89c4ff92SAndroid Build Coastguard Worker const auto factoryId = outSlot.GetTensorHandleFactoryId();
634*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(factoryId != ITensorHandleFactory::LegacyFactoryId);
635*89c4ff92SAndroid Build Coastguard Worker auto& workloadFactory = GetWorkloadFactory(*layer);
636*89c4ff92SAndroid Build Coastguard Worker
637*89c4ff92SAndroid Build Coastguard Worker layer->CreateTensorHandles(m_TensorHandleFactoryRegistry, workloadFactory);
638*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = outSlot.GetOutputHandler().GetData();
639*89c4ff92SAndroid Build Coastguard Worker
640*89c4ff92SAndroid Build Coastguard Worker m_ConstantTensorHandles[layer->GetGuid()] = tensorHandle;
641*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Allocate();
642*89c4ff92SAndroid Build Coastguard Worker
643*89c4ff92SAndroid Build Coastguard Worker auto& backend = m_Backends.at(layer->GetBackendId());
644*89c4ff92SAndroid Build Coastguard Worker
645*89c4ff92SAndroid Build Coastguard Worker WorkingMemDescriptor memDesc;
646*89c4ff92SAndroid Build Coastguard Worker memDesc.m_Outputs.push_back(tensorHandle);
647*89c4ff92SAndroid Build Coastguard Worker
648*89c4ff92SAndroid Build Coastguard Worker ExecutionData executionData = backend->CreateExecutionData(memDesc);
649*89c4ff92SAndroid Build Coastguard Worker m_ConstantWorkloads[layer->GetGuid()]->ExecuteAsync(executionData);
650*89c4ff92SAndroid Build Coastguard Worker }
651*89c4ff92SAndroid Build Coastguard Worker }
652*89c4ff92SAndroid Build Coastguard Worker }
653*89c4ff92SAndroid Build Coastguard Worker
SendNetworkStructure(arm::pipe::IProfilingService & profilingService)654*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::SendNetworkStructure(arm::pipe::IProfilingService& profilingService)
655*89c4ff92SAndroid Build Coastguard Worker {
656*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "LoadNetwork_SendNetworkStructure");
657*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
658*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
659*89c4ff92SAndroid Build Coastguard Worker
660*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
661*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(profilingService);
662*89c4ff92SAndroid Build Coastguard Worker
663*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateTypedEntity(networkGuid, LabelsAndEventClasses::NETWORK_GUID);
664*89c4ff92SAndroid Build Coastguard Worker
665*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
666*89c4ff92SAndroid Build Coastguard Worker {
667*89c4ff92SAndroid Build Coastguard Worker // Add layer to the post-optimisation network structure
668*89c4ff92SAndroid Build Coastguard Worker AddLayerStructure(timelineUtils, *layer, networkGuid);
669*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType())
670*89c4ff92SAndroid Build Coastguard Worker {
671*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
672*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
673*89c4ff92SAndroid Build Coastguard Worker {
674*89c4ff92SAndroid Build Coastguard Worker // Inputs and outputs are treated in a special way - see EnqueueInput() and EnqueueOutput().
675*89c4ff92SAndroid Build Coastguard Worker break;
676*89c4ff92SAndroid Build Coastguard Worker }
677*89c4ff92SAndroid Build Coastguard Worker default:
678*89c4ff92SAndroid Build Coastguard Worker {
679*89c4ff92SAndroid Build Coastguard Worker for (auto& workload : m_WorkloadQueue)
680*89c4ff92SAndroid Build Coastguard Worker {
681*89c4ff92SAndroid Build Coastguard Worker // Add workload to the post-optimisation network structure
682*89c4ff92SAndroid Build Coastguard Worker AddWorkloadStructure(timelineUtils, workload, *layer);
683*89c4ff92SAndroid Build Coastguard Worker }
684*89c4ff92SAndroid Build Coastguard Worker break;
685*89c4ff92SAndroid Build Coastguard Worker }
686*89c4ff92SAndroid Build Coastguard Worker }
687*89c4ff92SAndroid Build Coastguard Worker }
688*89c4ff92SAndroid Build Coastguard Worker // Commit to send the post-optimisation network structure
689*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
690*89c4ff92SAndroid Build Coastguard Worker }
691*89c4ff92SAndroid Build Coastguard Worker
GetNetworkGuid()692*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid LoadedNetwork::GetNetworkGuid()
693*89c4ff92SAndroid Build Coastguard Worker {
694*89c4ff92SAndroid Build Coastguard Worker return m_OptimizedNetwork->GetGuid();
695*89c4ff92SAndroid Build Coastguard Worker }
696*89c4ff92SAndroid Build Coastguard Worker
GetInputTensorInfo(LayerBindingId layerId) const697*89c4ff92SAndroid Build Coastguard Worker TensorInfo LoadedNetwork::GetInputTensorInfo(LayerBindingId layerId) const
698*89c4ff92SAndroid Build Coastguard Worker {
699*89c4ff92SAndroid Build Coastguard Worker for (auto&& inputLayer : m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetInputLayers())
700*89c4ff92SAndroid Build Coastguard Worker {
701*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(inputLayer->GetNumOutputSlots() == 1, "Input layer should have exactly 1 output slot");
702*89c4ff92SAndroid Build Coastguard Worker if (inputLayer->GetBindingId() == layerId)
703*89c4ff92SAndroid Build Coastguard Worker {
704*89c4ff92SAndroid Build Coastguard Worker return inputLayer->GetOutputSlot(0).GetTensorInfo();
705*89c4ff92SAndroid Build Coastguard Worker }
706*89c4ff92SAndroid Build Coastguard Worker }
707*89c4ff92SAndroid Build Coastguard Worker
708*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("No input layer is associated with id {}", layerId));
709*89c4ff92SAndroid Build Coastguard Worker }
710*89c4ff92SAndroid Build Coastguard Worker
GetOutputTensorInfo(LayerBindingId layerId) const711*89c4ff92SAndroid Build Coastguard Worker TensorInfo LoadedNetwork::GetOutputTensorInfo(LayerBindingId layerId) const
712*89c4ff92SAndroid Build Coastguard Worker {
713*89c4ff92SAndroid Build Coastguard Worker for (auto&& outputLayer : m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetOutputLayers())
714*89c4ff92SAndroid Build Coastguard Worker {
715*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(outputLayer->GetNumInputSlots() == 1, "Output layer should have exactly 1 input slot");
716*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(outputLayer->GetInputSlot(0).GetConnection(), "Input slot on Output layer must be connected");
717*89c4ff92SAndroid Build Coastguard Worker if (outputLayer->GetBindingId() == layerId)
718*89c4ff92SAndroid Build Coastguard Worker {
719*89c4ff92SAndroid Build Coastguard Worker return outputLayer->GetInputSlot(0).GetConnection()->GetTensorInfo();
720*89c4ff92SAndroid Build Coastguard Worker }
721*89c4ff92SAndroid Build Coastguard Worker }
722*89c4ff92SAndroid Build Coastguard Worker
723*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("No output layer is associated with id {}", layerId));
724*89c4ff92SAndroid Build Coastguard Worker }
725*89c4ff92SAndroid Build Coastguard Worker
GetWorkloadFactory(const Layer & layer) const726*89c4ff92SAndroid Build Coastguard Worker const IWorkloadFactory& LoadedNetwork::GetWorkloadFactory(const Layer& layer) const
727*89c4ff92SAndroid Build Coastguard Worker {
728*89c4ff92SAndroid Build Coastguard Worker const IWorkloadFactory* workloadFactory = nullptr;
729*89c4ff92SAndroid Build Coastguard Worker
730*89c4ff92SAndroid Build Coastguard Worker auto it = m_WorkloadFactories.find(layer.GetBackendId());
731*89c4ff92SAndroid Build Coastguard Worker if (it == m_WorkloadFactories.end())
732*89c4ff92SAndroid Build Coastguard Worker {
733*89c4ff92SAndroid Build Coastguard Worker throw RuntimeException(fmt::format("No workload factory for {0} to be used for layer: {1}",
734*89c4ff92SAndroid Build Coastguard Worker layer.GetBackendId().Get(),
735*89c4ff92SAndroid Build Coastguard Worker layer.GetNameStr()),
736*89c4ff92SAndroid Build Coastguard Worker CHECK_LOCATION());
737*89c4ff92SAndroid Build Coastguard Worker }
738*89c4ff92SAndroid Build Coastguard Worker
739*89c4ff92SAndroid Build Coastguard Worker workloadFactory = it->second.get();
740*89c4ff92SAndroid Build Coastguard Worker
741*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(workloadFactory, "No workload factory");
742*89c4ff92SAndroid Build Coastguard Worker
743*89c4ff92SAndroid Build Coastguard Worker return *workloadFactory;
744*89c4ff92SAndroid Build Coastguard Worker }
745*89c4ff92SAndroid Build Coastguard Worker
746*89c4ff92SAndroid Build Coastguard Worker namespace {
747*89c4ff92SAndroid Build Coastguard Worker
748*89c4ff92SAndroid Build Coastguard Worker // Non-copyable class owning accelerator-specific tensor data.
749*89c4ff92SAndroid Build Coastguard Worker class TensorPin
750*89c4ff92SAndroid Build Coastguard Worker {
751*89c4ff92SAndroid Build Coastguard Worker public:
TensorPin(std::unique_ptr<ITensorHandle> handle,const TensorInfo & info,LayerBindingId id)752*89c4ff92SAndroid Build Coastguard Worker TensorPin(std::unique_ptr<ITensorHandle> handle, const TensorInfo& info, LayerBindingId id)
753*89c4ff92SAndroid Build Coastguard Worker : m_TensorHandle(std::move(handle))
754*89c4ff92SAndroid Build Coastguard Worker , m_TensorInfo(info)
755*89c4ff92SAndroid Build Coastguard Worker , m_Id(id)
756*89c4ff92SAndroid Build Coastguard Worker {
757*89c4ff92SAndroid Build Coastguard Worker }
758*89c4ff92SAndroid Build Coastguard Worker
GetTensorHandle() const759*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* GetTensorHandle() const { return m_TensorHandle.get(); }
GetTensorInfo() const760*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
GetBindingId() const761*89c4ff92SAndroid Build Coastguard Worker LayerBindingId GetBindingId() const { return m_Id; }
762*89c4ff92SAndroid Build Coastguard Worker
763*89c4ff92SAndroid Build Coastguard Worker private:
764*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> m_TensorHandle;
765*89c4ff92SAndroid Build Coastguard Worker TensorInfo m_TensorInfo;
766*89c4ff92SAndroid Build Coastguard Worker LayerBindingId m_Id;
767*89c4ff92SAndroid Build Coastguard Worker };
768*89c4ff92SAndroid Build Coastguard Worker
GetTensorPin(LayerBindingId id,const std::vector<TensorPin> & pins,char const * bindingPointDesc)769*89c4ff92SAndroid Build Coastguard Worker static const TensorPin& GetTensorPin(LayerBindingId id,
770*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorPin>& pins,
771*89c4ff92SAndroid Build Coastguard Worker char const* bindingPointDesc)
772*89c4ff92SAndroid Build Coastguard Worker {
773*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(pins.begin(), pins.end(),
774*89c4ff92SAndroid Build Coastguard Worker [id](const TensorPin& pin)
775*89c4ff92SAndroid Build Coastguard Worker {
776*89c4ff92SAndroid Build Coastguard Worker return pin.GetBindingId() == id;
777*89c4ff92SAndroid Build Coastguard Worker });
778*89c4ff92SAndroid Build Coastguard Worker
779*89c4ff92SAndroid Build Coastguard Worker if (it != pins.end())
780*89c4ff92SAndroid Build Coastguard Worker {
781*89c4ff92SAndroid Build Coastguard Worker return *it;
782*89c4ff92SAndroid Build Coastguard Worker }
783*89c4ff92SAndroid Build Coastguard Worker else
784*89c4ff92SAndroid Build Coastguard Worker {
785*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("No tensor supplied for {0} {1}", bindingPointDesc, id));
786*89c4ff92SAndroid Build Coastguard Worker }
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker
789*89c4ff92SAndroid Build Coastguard Worker // Stores data that needs to be kept accessible for the entire execution of a workload.
790*89c4ff92SAndroid Build Coastguard Worker class WorkloadData
791*89c4ff92SAndroid Build Coastguard Worker {
792*89c4ff92SAndroid Build Coastguard Worker public:
WorkloadData(const InputTensors & inputTensors,const OutputTensors & outputTensors)793*89c4ff92SAndroid Build Coastguard Worker WorkloadData(const InputTensors& inputTensors, const OutputTensors& outputTensors)
794*89c4ff92SAndroid Build Coastguard Worker {
795*89c4ff92SAndroid Build Coastguard Worker m_InputTensorPins.reserve(inputTensors.size());
796*89c4ff92SAndroid Build Coastguard Worker m_OutputTensorPins.reserve(outputTensors.size());
797*89c4ff92SAndroid Build Coastguard Worker
798*89c4ff92SAndroid Build Coastguard Worker for (auto inputTensorPair : inputTensors)
799*89c4ff92SAndroid Build Coastguard Worker {
800*89c4ff92SAndroid Build Coastguard Worker auto inputTensor = inputTensorPair.second;
801*89c4ff92SAndroid Build Coastguard Worker
802*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
803*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(),inputTensor.GetMemoryArea());
804*89c4ff92SAndroid Build Coastguard Worker LayerBindingId layerId = inputTensorPair.first;
805*89c4ff92SAndroid Build Coastguard Worker
806*89c4ff92SAndroid Build Coastguard Worker m_InputTensorPins.emplace_back(std::move(tensorHandle), inputTensor.GetInfo(), layerId);
807*89c4ff92SAndroid Build Coastguard Worker }
808*89c4ff92SAndroid Build Coastguard Worker
809*89c4ff92SAndroid Build Coastguard Worker for (auto outputTensorPair : outputTensors)
810*89c4ff92SAndroid Build Coastguard Worker {
811*89c4ff92SAndroid Build Coastguard Worker auto outputTensor = outputTensorPair.second;
812*89c4ff92SAndroid Build Coastguard Worker
813*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
814*89c4ff92SAndroid Build Coastguard Worker std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(), outputTensor.GetMemoryArea());
815*89c4ff92SAndroid Build Coastguard Worker LayerBindingId layerId = outputTensorPair.first;
816*89c4ff92SAndroid Build Coastguard Worker
817*89c4ff92SAndroid Build Coastguard Worker m_OutputTensorPins.emplace_back(std::move(tensorHandle), outputTensor.GetInfo(), layerId);
818*89c4ff92SAndroid Build Coastguard Worker }
819*89c4ff92SAndroid Build Coastguard Worker }
820*89c4ff92SAndroid Build Coastguard Worker
GetInputTensorPin(LayerBindingId id) const821*89c4ff92SAndroid Build Coastguard Worker const TensorPin& GetInputTensorPin(LayerBindingId id) const
822*89c4ff92SAndroid Build Coastguard Worker {
823*89c4ff92SAndroid Build Coastguard Worker return GetTensorPin(id, m_InputTensorPins, "input");
824*89c4ff92SAndroid Build Coastguard Worker }
825*89c4ff92SAndroid Build Coastguard Worker
GetOutputTensorPin(LayerBindingId id) const826*89c4ff92SAndroid Build Coastguard Worker const TensorPin& GetOutputTensorPin(LayerBindingId id) const
827*89c4ff92SAndroid Build Coastguard Worker {
828*89c4ff92SAndroid Build Coastguard Worker return GetTensorPin(id, m_OutputTensorPins, "output");
829*89c4ff92SAndroid Build Coastguard Worker }
830*89c4ff92SAndroid Build Coastguard Worker
831*89c4ff92SAndroid Build Coastguard Worker private:
832*89c4ff92SAndroid Build Coastguard Worker
833*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorPin> m_InputTensorPins;
834*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorPin> m_OutputTensorPins;
835*89c4ff92SAndroid Build Coastguard Worker };
836*89c4ff92SAndroid Build Coastguard Worker
837*89c4ff92SAndroid Build Coastguard Worker }
838*89c4ff92SAndroid Build Coastguard Worker
EnqueueWorkload(const InputTensors & inputTensors,const OutputTensors & outputTensors,std::vector<ImportedInputId> preImportedInputIds,std::vector<ImportedOutputId> preImportedOutputIds)839*89c4ff92SAndroid Build Coastguard Worker Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
840*89c4ff92SAndroid Build Coastguard Worker const OutputTensors& outputTensors,
841*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> preImportedInputIds,
842*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> preImportedOutputIds)
843*89c4ff92SAndroid Build Coastguard Worker {
844*89c4ff92SAndroid Build Coastguard Worker const Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
845*89c4ff92SAndroid Build Coastguard Worker
846*89c4ff92SAndroid Build Coastguard Worker // Walk graph to determine the order of execution.
847*89c4ff92SAndroid Build Coastguard Worker if (graph.GetNumLayers() < 2)
848*89c4ff92SAndroid Build Coastguard Worker {
849*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(warning) << "IRuntime::EnqueueWorkload()::Less than two nodes in graph";
850*89c4ff92SAndroid Build Coastguard Worker return Status::Failure;
851*89c4ff92SAndroid Build Coastguard Worker }
852*89c4ff92SAndroid Build Coastguard Worker
853*89c4ff92SAndroid Build Coastguard Worker // Data that must be kept alive for the entire execution of the workload.
854*89c4ff92SAndroid Build Coastguard Worker WorkloadData workloadData(inputTensors, outputTensors);
855*89c4ff92SAndroid Build Coastguard Worker
856*89c4ff92SAndroid Build Coastguard Worker // Input tensors can be provided as parameters or pre imported. Either way the number of
857*89c4ff92SAndroid Build Coastguard Worker // tensors should match the number of inputs.
858*89c4ff92SAndroid Build Coastguard Worker if (graph.GetNumInputs() != (inputTensors.size() + preImportedInputIds.size()))
859*89c4ff92SAndroid Build Coastguard Worker {
860*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Number of inputs provided does not match network.");
861*89c4ff92SAndroid Build Coastguard Worker }
862*89c4ff92SAndroid Build Coastguard Worker
863*89c4ff92SAndroid Build Coastguard Worker // For each input to the network, call EnqueueInput with the data passed by the user.
864*89c4ff92SAndroid Build Coastguard Worker {
865*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs");
866*89c4ff92SAndroid Build Coastguard Worker m_InputQueue.clear();
867*89c4ff92SAndroid Build Coastguard Worker m_InputQueue.reserve(graph.GetNumInputs());
868*89c4ff92SAndroid Build Coastguard Worker
869*89c4ff92SAndroid Build Coastguard Worker unsigned int inputIndex = 0;
870*89c4ff92SAndroid Build Coastguard Worker unsigned int importedInputIdIndex = 0;
871*89c4ff92SAndroid Build Coastguard Worker std::sort(preImportedInputIds.begin(), preImportedInputIds.end());
872*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* inputLayer : graph.GetInputLayers())
873*89c4ff92SAndroid Build Coastguard Worker {
874*89c4ff92SAndroid Build Coastguard Worker if (importedInputIdIndex < preImportedInputIds.size() &&
875*89c4ff92SAndroid Build Coastguard Worker inputIndex == preImportedInputIds[importedInputIdIndex])
876*89c4ff92SAndroid Build Coastguard Worker {
877*89c4ff92SAndroid Build Coastguard Worker // Only replace tensorhandles if they have not already been replaced
878*89c4ff92SAndroid Build Coastguard Worker if (!m_IsInputImported[inputIndex])
879*89c4ff92SAndroid Build Coastguard Worker {
880*89c4ff92SAndroid Build Coastguard Worker auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
881*89c4ff92SAndroid Build Coastguard Worker
882*89c4ff92SAndroid Build Coastguard Worker for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
883*89c4ff92SAndroid Build Coastguard Worker {
884*89c4ff92SAndroid Build Coastguard Worker auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
885*89c4ff92SAndroid Build Coastguard Worker workload->ReplaceInputTensorHandle(outputTensorHandle, workloadInfo.m_SlotIndex);
886*89c4ff92SAndroid Build Coastguard Worker }
887*89c4ff92SAndroid Build Coastguard Worker m_IsInputImported[inputIndex] = true;
888*89c4ff92SAndroid Build Coastguard Worker }
889*89c4ff92SAndroid Build Coastguard Worker importedInputIdIndex++;
890*89c4ff92SAndroid Build Coastguard Worker }
891*89c4ff92SAndroid Build Coastguard Worker else
892*89c4ff92SAndroid Build Coastguard Worker {
893*89c4ff92SAndroid Build Coastguard Worker if (m_IsInputImported[inputIndex])
894*89c4ff92SAndroid Build Coastguard Worker {
895*89c4ff92SAndroid Build Coastguard Worker OutputHandler& handler = const_cast<OutputHandler&>(inputLayer->GetOutputHandler(0));
896*89c4ff92SAndroid Build Coastguard Worker
897*89c4ff92SAndroid Build Coastguard Worker for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
898*89c4ff92SAndroid Build Coastguard Worker {
899*89c4ff92SAndroid Build Coastguard Worker auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
900*89c4ff92SAndroid Build Coastguard Worker workload->ReplaceInputTensorHandle(handler.GetData(), workloadInfo.m_SlotIndex);
901*89c4ff92SAndroid Build Coastguard Worker }
902*89c4ff92SAndroid Build Coastguard Worker
903*89c4ff92SAndroid Build Coastguard Worker m_IsInputImported[inputIndex] = false;
904*89c4ff92SAndroid Build Coastguard Worker }
905*89c4ff92SAndroid Build Coastguard Worker
906*89c4ff92SAndroid Build Coastguard Worker // InputTensorHandle is not imported yet, process to enqueue input
907*89c4ff92SAndroid Build Coastguard Worker const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId());
908*89c4ff92SAndroid Build Coastguard Worker EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
909*89c4ff92SAndroid Build Coastguard Worker }
910*89c4ff92SAndroid Build Coastguard Worker inputIndex++;
911*89c4ff92SAndroid Build Coastguard Worker }
912*89c4ff92SAndroid Build Coastguard Worker }
913*89c4ff92SAndroid Build Coastguard Worker // For each output to the network, call EnqueueOutput with the data passed by the user.
914*89c4ff92SAndroid Build Coastguard Worker {
915*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs");
916*89c4ff92SAndroid Build Coastguard Worker m_OutputQueue.clear();
917*89c4ff92SAndroid Build Coastguard Worker m_OutputQueue.reserve(graph.GetNumOutputs());
918*89c4ff92SAndroid Build Coastguard Worker
919*89c4ff92SAndroid Build Coastguard Worker if (preImportedOutputIds.size() > graph.GetNumOutputs())
920*89c4ff92SAndroid Build Coastguard Worker {
921*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of preImportedOutputIds");
922*89c4ff92SAndroid Build Coastguard Worker }
923*89c4ff92SAndroid Build Coastguard Worker
924*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = 0;
925*89c4ff92SAndroid Build Coastguard Worker unsigned int importedOutputIdIndex = 0;
926*89c4ff92SAndroid Build Coastguard Worker std::sort(preImportedOutputIds.begin(), preImportedOutputIds.end());
927*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* outputLayer : graph.GetOutputLayers())
928*89c4ff92SAndroid Build Coastguard Worker {
929*89c4ff92SAndroid Build Coastguard Worker if (importedOutputIdIndex < preImportedOutputIds.size() &&
930*89c4ff92SAndroid Build Coastguard Worker outputIndex == preImportedOutputIds[importedOutputIdIndex])
931*89c4ff92SAndroid Build Coastguard Worker {
932*89c4ff92SAndroid Build Coastguard Worker // Only replace tensorhandles if they have not already been replaced
933*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
934*89c4ff92SAndroid Build Coastguard Worker
935*89c4ff92SAndroid Build Coastguard Worker if (!m_IsOutputImported[outputIndex])
936*89c4ff92SAndroid Build Coastguard Worker {
937*89c4ff92SAndroid Build Coastguard Worker const auto bindingId = outputLayer->GetBindingId();
938*89c4ff92SAndroid Build Coastguard Worker const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
939*89c4ff92SAndroid Build Coastguard Worker
940*89c4ff92SAndroid Build Coastguard Worker auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
941*89c4ff92SAndroid Build Coastguard Worker
942*89c4ff92SAndroid Build Coastguard Worker outputWorkload->ReplaceOutputTensorHandle(inputTensorHandle,
943*89c4ff92SAndroid Build Coastguard Worker indices.m_OutputSlotIndices.m_SlotIndex);
944*89c4ff92SAndroid Build Coastguard Worker
945*89c4ff92SAndroid Build Coastguard Worker for (const auto& workloadInfo: indices.m_InputSlotIndices)
946*89c4ff92SAndroid Build Coastguard Worker {
947*89c4ff92SAndroid Build Coastguard Worker auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
948*89c4ff92SAndroid Build Coastguard Worker inputWorkload->ReplaceInputTensorHandle(inputTensorHandle, workloadInfo.m_SlotIndex);
949*89c4ff92SAndroid Build Coastguard Worker }
950*89c4ff92SAndroid Build Coastguard Worker m_IsOutputImported[outputIndex] = true;
951*89c4ff92SAndroid Build Coastguard Worker }
952*89c4ff92SAndroid Build Coastguard Worker
953*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated.");
954*89c4ff92SAndroid Build Coastguard Worker MemSyncQueueDescriptor syncDesc;
955*89c4ff92SAndroid Build Coastguard Worker syncDesc.m_Inputs.push_back(inputTensorHandle);
956*89c4ff92SAndroid Build Coastguard Worker WorkloadInfo info;
957*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos.push_back(
958*89c4ff92SAndroid Build Coastguard Worker outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo());
959*89c4ff92SAndroid Build Coastguard Worker auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
960*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
961*89c4ff92SAndroid Build Coastguard Worker m_OutputQueue.push_back(move(syncWorkload));
962*89c4ff92SAndroid Build Coastguard Worker importedOutputIdIndex++;
963*89c4ff92SAndroid Build Coastguard Worker }
964*89c4ff92SAndroid Build Coastguard Worker else
965*89c4ff92SAndroid Build Coastguard Worker {
966*89c4ff92SAndroid Build Coastguard Worker if (m_IsOutputImported[outputIndex])
967*89c4ff92SAndroid Build Coastguard Worker {
968*89c4ff92SAndroid Build Coastguard Worker const auto bindingId = outputLayer->GetBindingId();
969*89c4ff92SAndroid Build Coastguard Worker const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
970*89c4ff92SAndroid Build Coastguard Worker
971*89c4ff92SAndroid Build Coastguard Worker auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
972*89c4ff92SAndroid Build Coastguard Worker const OutputHandler& outputHandler =
973*89c4ff92SAndroid Build Coastguard Worker outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOutputHandler();
974*89c4ff92SAndroid Build Coastguard Worker
975*89c4ff92SAndroid Build Coastguard Worker outputWorkload->ReplaceOutputTensorHandle(
976*89c4ff92SAndroid Build Coastguard Worker outputHandler.GetData(), indices.m_OutputSlotIndices.m_SlotIndex);
977*89c4ff92SAndroid Build Coastguard Worker
978*89c4ff92SAndroid Build Coastguard Worker for (const auto& workloadInfo: indices.m_InputSlotIndices)
979*89c4ff92SAndroid Build Coastguard Worker {
980*89c4ff92SAndroid Build Coastguard Worker auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
981*89c4ff92SAndroid Build Coastguard Worker inputWorkload->ReplaceInputTensorHandle(outputHandler.GetData(), workloadInfo.m_SlotIndex);
982*89c4ff92SAndroid Build Coastguard Worker }
983*89c4ff92SAndroid Build Coastguard Worker m_IsOutputImported[outputIndex] = false;
984*89c4ff92SAndroid Build Coastguard Worker }
985*89c4ff92SAndroid Build Coastguard Worker
986*89c4ff92SAndroid Build Coastguard Worker const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId());
987*89c4ff92SAndroid Build Coastguard Worker // OutputTensorHandle is not imported yet, process to enqueue Output
988*89c4ff92SAndroid Build Coastguard Worker EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
989*89c4ff92SAndroid Build Coastguard Worker }
990*89c4ff92SAndroid Build Coastguard Worker outputIndex++;
991*89c4ff92SAndroid Build Coastguard Worker }
992*89c4ff92SAndroid Build Coastguard Worker }
993*89c4ff92SAndroid Build Coastguard Worker
994*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
995*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
996*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid inferenceGuid = m_ProfilingService->GetNextGuid();
997*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
998*89c4ff92SAndroid Build Coastguard Worker {
999*89c4ff92SAndroid Build Coastguard Worker // Add inference timeline trace if profiling is enabled.
1000*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
1001*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateTypedEntity(inferenceGuid, LabelsAndEventClasses::INFERENCE_GUID);
1002*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
1003*89c4ff92SAndroid Build Coastguard Worker networkGuid,
1004*89c4ff92SAndroid Build Coastguard Worker inferenceGuid,
1005*89c4ff92SAndroid Build Coastguard Worker LabelsAndEventClasses::EXECUTION_OF_GUID);
1006*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
1007*89c4ff92SAndroid Build Coastguard Worker }
1008*89c4ff92SAndroid Build Coastguard Worker
1009*89c4ff92SAndroid Build Coastguard Worker bool executionSucceeded = true;
1010*89c4ff92SAndroid Build Coastguard Worker
1011*89c4ff92SAndroid Build Coastguard Worker {
1012*89c4ff92SAndroid Build Coastguard Worker if (m_ProfilingService->IsProfilingEnabled())
1013*89c4ff92SAndroid Build Coastguard Worker {
1014*89c4ff92SAndroid Build Coastguard Worker m_ProfilingService->IncrementCounterValue(INFERENCES_RUN);
1015*89c4ff92SAndroid Build Coastguard Worker }
1016*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Execute");
1017*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_HEAP_PROFILING("Executing");
1018*89c4ff92SAndroid Build Coastguard Worker executionSucceeded = Execute(timelineUtils, inferenceGuid);
1019*89c4ff92SAndroid Build Coastguard Worker }
1020*89c4ff92SAndroid Build Coastguard Worker
1021*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1022*89c4ff92SAndroid Build Coastguard Worker {
1023*89c4ff92SAndroid Build Coastguard Worker // Add end of life of the inference timeline if profiling is enabled.
1024*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
1025*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
1026*89c4ff92SAndroid Build Coastguard Worker }
1027*89c4ff92SAndroid Build Coastguard Worker
1028*89c4ff92SAndroid Build Coastguard Worker return executionSucceeded ? Status::Success : Status::Failure;
1029*89c4ff92SAndroid Build Coastguard Worker }
1030*89c4ff92SAndroid Build Coastguard Worker
EnqueueInput(const BindableLayer & layer,ITensorHandle * tensorHandle,const TensorInfo & tensorInfo)1031*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo)
1032*89c4ff92SAndroid Build Coastguard Worker {
1033*89c4ff92SAndroid Build Coastguard Worker if (layer.GetType() != LayerType::Input)
1034*89c4ff92SAndroid Build Coastguard Worker {
1035*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("EnqueueInput: given layer not an InputLayer");
1036*89c4ff92SAndroid Build Coastguard Worker }
1037*89c4ff92SAndroid Build Coastguard Worker
1038*89c4ff92SAndroid Build Coastguard Worker if (tensorHandle == nullptr)
1039*89c4ff92SAndroid Build Coastguard Worker {
1040*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("EnqueueInput: tensorHandle must not be NULL");
1041*89c4ff92SAndroid Build Coastguard Worker }
1042*89c4ff92SAndroid Build Coastguard Worker
1043*89c4ff92SAndroid Build Coastguard Worker InputQueueDescriptor inputQueueDescriptor;
1044*89c4ff92SAndroid Build Coastguard Worker WorkloadInfo info;
1045*89c4ff92SAndroid Build Coastguard Worker
1046*89c4ff92SAndroid Build Coastguard Worker inputQueueDescriptor.m_Inputs.push_back(tensorHandle);
1047*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos.push_back(tensorInfo);
1048*89c4ff92SAndroid Build Coastguard Worker
1049*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(layer.GetNumOutputSlots() == 1, "Can only handle Input Layer with one output");
1050*89c4ff92SAndroid Build Coastguard Worker const OutputHandler& handler = layer.GetOutputHandler();
1051*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputTensorInfo = handler.GetTensorInfo();
1052*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* outputTensorHandle = handler.GetData();
1053*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(outputTensorHandle != nullptr,
1054*89c4ff92SAndroid Build Coastguard Worker "Data should have been allocated.");
1055*89c4ff92SAndroid Build Coastguard Worker inputQueueDescriptor.m_Outputs.push_back(outputTensorHandle);
1056*89c4ff92SAndroid Build Coastguard Worker info.m_OutputTensorInfos.push_back(outputTensorInfo);
1057*89c4ff92SAndroid Build Coastguard Worker
1058*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = outputTensorHandle->GetImportFlags();
1059*89c4ff92SAndroid Build Coastguard Worker bool needMemCopy = true;
1060*89c4ff92SAndroid Build Coastguard Worker if (m_NetworkProperties.m_ImportEnabled) // Try import the input tensor
1061*89c4ff92SAndroid Build Coastguard Worker {
1062*89c4ff92SAndroid Build Coastguard Worker if(CheckFlag(importFlags, m_NetworkProperties.m_InputSource))
1063*89c4ff92SAndroid Build Coastguard Worker {
1064*89c4ff92SAndroid Build Coastguard Worker needMemCopy = false;
1065*89c4ff92SAndroid Build Coastguard Worker // This assumes a CPU Tensor handle
1066*89c4ff92SAndroid Build Coastguard Worker void* mem = tensorHandle->Map(false);
1067*89c4ff92SAndroid Build Coastguard Worker if (outputTensorHandle->Import(mem, m_NetworkProperties.m_InputSource))
1068*89c4ff92SAndroid Build Coastguard Worker {
1069*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1070*89c4ff92SAndroid Build Coastguard Worker return; // No need for a workload since the import has been done.
1071*89c4ff92SAndroid Build Coastguard Worker }
1072*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1073*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("EnqueueInput: Memory Import failed");
1074*89c4ff92SAndroid Build Coastguard Worker }
1075*89c4ff92SAndroid Build Coastguard Worker }
1076*89c4ff92SAndroid Build Coastguard Worker if (needMemCopy)
1077*89c4ff92SAndroid Build Coastguard Worker {
1078*89c4ff92SAndroid Build Coastguard Worker // Create a mem copy workload for input since we did not import
1079*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> inputWorkload = std::make_unique<CopyMemGenericWorkload>(inputQueueDescriptor, info);
1080*89c4ff92SAndroid Build Coastguard Worker
1081*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(inputWorkload, "No input workload created");
1082*89c4ff92SAndroid Build Coastguard Worker
1083*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1084*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1085*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1086*89c4ff92SAndroid Build Coastguard Worker {
1087*89c4ff92SAndroid Build Coastguard Worker // Add Input Workload to the post-optimisation network structure
1088*89c4ff92SAndroid Build Coastguard Worker AddWorkloadStructure(timelineUtils, inputWorkload, layer);
1089*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
1090*89c4ff92SAndroid Build Coastguard Worker }
1091*89c4ff92SAndroid Build Coastguard Worker
1092*89c4ff92SAndroid Build Coastguard Worker m_InputQueue.push_back(move(inputWorkload));
1093*89c4ff92SAndroid Build Coastguard Worker }
1094*89c4ff92SAndroid Build Coastguard Worker }
1095*89c4ff92SAndroid Build Coastguard Worker
EnqueueOutput(const BindableLayer & layer,ITensorHandle * tensorHandle,const TensorInfo & tensorInfo)1096*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* tensorHandle, const TensorInfo& tensorInfo)
1097*89c4ff92SAndroid Build Coastguard Worker {
1098*89c4ff92SAndroid Build Coastguard Worker if (layer.GetType() != LayerType::Output)
1099*89c4ff92SAndroid Build Coastguard Worker {
1100*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("EnqueueOutput: given layer not an OutputLayer");
1101*89c4ff92SAndroid Build Coastguard Worker }
1102*89c4ff92SAndroid Build Coastguard Worker
1103*89c4ff92SAndroid Build Coastguard Worker if (tensorHandle == nullptr)
1104*89c4ff92SAndroid Build Coastguard Worker {
1105*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("EnqueueOutput: tensorHandle must not be NULL");
1106*89c4ff92SAndroid Build Coastguard Worker }
1107*89c4ff92SAndroid Build Coastguard Worker
1108*89c4ff92SAndroid Build Coastguard Worker OutputQueueDescriptor outputQueueDescriptor;
1109*89c4ff92SAndroid Build Coastguard Worker WorkloadInfo info;
1110*89c4ff92SAndroid Build Coastguard Worker
1111*89c4ff92SAndroid Build Coastguard Worker outputQueueDescriptor.m_Outputs.push_back(tensorHandle);
1112*89c4ff92SAndroid Build Coastguard Worker info.m_OutputTensorInfos.push_back(tensorInfo);
1113*89c4ff92SAndroid Build Coastguard Worker
1114*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(layer.GetNumInputSlots() == 1, "Output Layer should have exactly one input.");
1115*89c4ff92SAndroid Build Coastguard Worker
1116*89c4ff92SAndroid Build Coastguard Worker // Gets the output handler from the previous node.
1117*89c4ff92SAndroid Build Coastguard Worker const OutputHandler& outputHandler = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
1118*89c4ff92SAndroid Build Coastguard Worker
1119*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputTensorInfo = outputHandler.GetTensorInfo();
1120*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* inputTensorHandle = outputHandler.GetData();
1121*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated.");
1122*89c4ff92SAndroid Build Coastguard Worker
1123*89c4ff92SAndroid Build Coastguard Worker // Try import the output tensor.
1124*89c4ff92SAndroid Build Coastguard Worker // Note: We can only import the output pointer if all of the following hold true:
1125*89c4ff92SAndroid Build Coastguard Worker // a) The imported pointer is aligned sufficiently
1126*89c4ff92SAndroid Build Coastguard Worker // b) The tensor has zero padding
1127*89c4ff92SAndroid Build Coastguard Worker // c) There is only one connection to the OutputSlot and it is to an OutputLayer.
1128*89c4ff92SAndroid Build Coastguard Worker // d) The output pointer is allocated via malloc. (Other types will be supported in a later release)
1129*89c4ff92SAndroid Build Coastguard Worker // e) m_IsExportEnabled must be set to true
1130*89c4ff92SAndroid Build Coastguard Worker bool needMemCopy = true;
1131*89c4ff92SAndroid Build Coastguard Worker if (m_NetworkProperties.m_ExportEnabled &&
1132*89c4ff92SAndroid Build Coastguard Worker (layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetNumConnections() == 1))
1133*89c4ff92SAndroid Build Coastguard Worker {
1134*89c4ff92SAndroid Build Coastguard Worker if(layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer().GetType() != LayerType::Input)
1135*89c4ff92SAndroid Build Coastguard Worker {
1136*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
1137*89c4ff92SAndroid Build Coastguard Worker if (CheckFlag(importFlags, m_NetworkProperties.m_OutputSource))
1138*89c4ff92SAndroid Build Coastguard Worker {
1139*89c4ff92SAndroid Build Coastguard Worker needMemCopy = false;
1140*89c4ff92SAndroid Build Coastguard Worker void *mem = tensorHandle->Map(false);
1141*89c4ff92SAndroid Build Coastguard Worker bool importOk = inputTensorHandle->Import(mem, m_NetworkProperties.m_OutputSource);
1142*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1143*89c4ff92SAndroid Build Coastguard Worker
1144*89c4ff92SAndroid Build Coastguard Worker if (importOk)
1145*89c4ff92SAndroid Build Coastguard Worker {
1146*89c4ff92SAndroid Build Coastguard Worker // Insert synchronization workload
1147*89c4ff92SAndroid Build Coastguard Worker MemSyncQueueDescriptor syncDesc;
1148*89c4ff92SAndroid Build Coastguard Worker syncDesc.m_Inputs.push_back(inputTensorHandle);
1149*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos.push_back(inputTensorInfo);
1150*89c4ff92SAndroid Build Coastguard Worker auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
1151*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
1152*89c4ff92SAndroid Build Coastguard Worker m_OutputQueue.push_back(move(syncWorkload));
1153*89c4ff92SAndroid Build Coastguard Worker }
1154*89c4ff92SAndroid Build Coastguard Worker else
1155*89c4ff92SAndroid Build Coastguard Worker {
1156*89c4ff92SAndroid Build Coastguard Worker throw MemoryExportException("EnqueueOutput: Memory Export failed");
1157*89c4ff92SAndroid Build Coastguard Worker }
1158*89c4ff92SAndroid Build Coastguard Worker }
1159*89c4ff92SAndroid Build Coastguard Worker }
1160*89c4ff92SAndroid Build Coastguard Worker }
1161*89c4ff92SAndroid Build Coastguard Worker if (needMemCopy)
1162*89c4ff92SAndroid Build Coastguard Worker {
1163*89c4ff92SAndroid Build Coastguard Worker // If we got here then we didn't export the memory, so add an output workload which performs a memcopy.
1164*89c4ff92SAndroid Build Coastguard Worker outputQueueDescriptor.m_Inputs.push_back(inputTensorHandle);
1165*89c4ff92SAndroid Build Coastguard Worker info.m_InputTensorInfos.push_back(inputTensorInfo);
1166*89c4ff92SAndroid Build Coastguard Worker
1167*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> outputWorkload =
1168*89c4ff92SAndroid Build Coastguard Worker std::make_unique<CopyMemGenericWorkload>(outputQueueDescriptor, info);
1169*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(outputWorkload, "No output workload created");
1170*89c4ff92SAndroid Build Coastguard Worker
1171*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1172*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1173*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1174*89c4ff92SAndroid Build Coastguard Worker {
1175*89c4ff92SAndroid Build Coastguard Worker // Add Output Workload to the post-optimisation network structure
1176*89c4ff92SAndroid Build Coastguard Worker AddWorkloadStructure(timelineUtils, outputWorkload, layer);
1177*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
1178*89c4ff92SAndroid Build Coastguard Worker }
1179*89c4ff92SAndroid Build Coastguard Worker
1180*89c4ff92SAndroid Build Coastguard Worker m_OutputQueue.push_back(move(outputWorkload));
1181*89c4ff92SAndroid Build Coastguard Worker }
1182*89c4ff92SAndroid Build Coastguard Worker }
1183*89c4ff92SAndroid Build Coastguard Worker
AllocateWorkingMemory(std::lock_guard<std::mutex> & lock)1184*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::AllocateWorkingMemory(
1185*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1186*89c4ff92SAndroid Build Coastguard Worker std::lock_guard<std::mutex>& lock
1187*89c4ff92SAndroid Build Coastguard Worker #endif
1188*89c4ff92SAndroid Build Coastguard Worker )
1189*89c4ff92SAndroid Build Coastguard Worker {
1190*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Working Memory Allocation");
1191*89c4ff92SAndroid Build Coastguard Worker
1192*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1193*89c4ff92SAndroid Build Coastguard Worker // this unused parameter makes sure we can only call this function with a valid lock
1194*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(lock);
1195*89c4ff92SAndroid Build Coastguard Worker #endif
1196*89c4ff92SAndroid Build Coastguard Worker if (m_IsWorkingMemAllocated)
1197*89c4ff92SAndroid Build Coastguard Worker {
1198*89c4ff92SAndroid Build Coastguard Worker return;
1199*89c4ff92SAndroid Build Coastguard Worker }
1200*89c4ff92SAndroid Build Coastguard Worker
1201*89c4ff92SAndroid Build Coastguard Worker if (m_ExternalMemoryManager)
1202*89c4ff92SAndroid Build Coastguard Worker {
1203*89c4ff92SAndroid Build Coastguard Worker m_ExternalMemoryManager->Allocate();
1204*89c4ff92SAndroid Build Coastguard Worker
1205*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_TensorMemory.size(); ++i)
1206*89c4ff92SAndroid Build Coastguard Worker {
1207*89c4ff92SAndroid Build Coastguard Worker m_Tensorhandles[i]->Import(m_TensorMemory[i].first->m_Data, m_TensorMemory[i].second);
1208*89c4ff92SAndroid Build Coastguard Worker }
1209*89c4ff92SAndroid Build Coastguard Worker }
1210*89c4ff92SAndroid Build Coastguard Worker
1211*89c4ff92SAndroid Build Coastguard Worker for (auto&& memoryManager : m_BackendMemoryMangers)
1212*89c4ff92SAndroid Build Coastguard Worker {
1213*89c4ff92SAndroid Build Coastguard Worker if (memoryManager)
1214*89c4ff92SAndroid Build Coastguard Worker {
1215*89c4ff92SAndroid Build Coastguard Worker memoryManager->Acquire();
1216*89c4ff92SAndroid Build Coastguard Worker }
1217*89c4ff92SAndroid Build Coastguard Worker }
1218*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry.AquireMemory();
1219*89c4ff92SAndroid Build Coastguard Worker m_IsWorkingMemAllocated = true;
1220*89c4ff92SAndroid Build Coastguard Worker }
1221*89c4ff92SAndroid Build Coastguard Worker
FreeWorkingMemory()1222*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::FreeWorkingMemory()
1223*89c4ff92SAndroid Build Coastguard Worker {
1224*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1225*89c4ff92SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lockGuard(m_WorkingMemMutex);
1226*89c4ff92SAndroid Build Coastguard Worker #endif
1227*89c4ff92SAndroid Build Coastguard Worker
1228*89c4ff92SAndroid Build Coastguard Worker if (!m_IsWorkingMemAllocated)
1229*89c4ff92SAndroid Build Coastguard Worker {
1230*89c4ff92SAndroid Build Coastguard Worker return;
1231*89c4ff92SAndroid Build Coastguard Worker }
1232*89c4ff92SAndroid Build Coastguard Worker
1233*89c4ff92SAndroid Build Coastguard Worker if (m_ExternalMemoryManager)
1234*89c4ff92SAndroid Build Coastguard Worker {
1235*89c4ff92SAndroid Build Coastguard Worker m_ExternalMemoryManager->Deallocate();
1236*89c4ff92SAndroid Build Coastguard Worker }
1237*89c4ff92SAndroid Build Coastguard Worker
1238*89c4ff92SAndroid Build Coastguard Worker // Informs the memory managers to release memory in its respective memory group
1239*89c4ff92SAndroid Build Coastguard Worker for (auto&& memoryManager : m_BackendMemoryMangers)
1240*89c4ff92SAndroid Build Coastguard Worker {
1241*89c4ff92SAndroid Build Coastguard Worker if (memoryManager)
1242*89c4ff92SAndroid Build Coastguard Worker {
1243*89c4ff92SAndroid Build Coastguard Worker memoryManager->Release();
1244*89c4ff92SAndroid Build Coastguard Worker }
1245*89c4ff92SAndroid Build Coastguard Worker }
1246*89c4ff92SAndroid Build Coastguard Worker m_TensorHandleFactoryRegistry.ReleaseMemory();
1247*89c4ff92SAndroid Build Coastguard Worker m_IsWorkingMemAllocated = false;
1248*89c4ff92SAndroid Build Coastguard Worker }
1249*89c4ff92SAndroid Build Coastguard Worker
Execute(std::unique_ptr<TimelineUtilityMethods> & timelineUtils,ProfilingGuid inferenceGuid)1250*89c4ff92SAndroid Build Coastguard Worker bool LoadedNetwork::Execute(std::unique_ptr<TimelineUtilityMethods>& timelineUtils,
1251*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid inferenceGuid)
1252*89c4ff92SAndroid Build Coastguard Worker {
1253*89c4ff92SAndroid Build Coastguard Worker bool success = true;
1254*89c4ff92SAndroid Build Coastguard Worker
1255*89c4ff92SAndroid Build Coastguard Worker auto Fail = [&](const std::exception& error)
1256*89c4ff92SAndroid Build Coastguard Worker {
1257*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "An error occurred attempting to execute a workload: " << error.what();
1258*89c4ff92SAndroid Build Coastguard Worker success = false;
1259*89c4ff92SAndroid Build Coastguard Worker };
1260*89c4ff92SAndroid Build Coastguard Worker
1261*89c4ff92SAndroid Build Coastguard Worker try
1262*89c4ff92SAndroid Build Coastguard Worker {
1263*89c4ff92SAndroid Build Coastguard Worker #if !defined(ARMNN_DISABLE_THREADS)
1264*89c4ff92SAndroid Build Coastguard Worker std::lock_guard<std::mutex> lockGuard(m_WorkingMemMutex);
1265*89c4ff92SAndroid Build Coastguard Worker AllocateWorkingMemory(lockGuard);
1266*89c4ff92SAndroid Build Coastguard Worker #else
1267*89c4ff92SAndroid Build Coastguard Worker AllocateWorkingMemory();
1268*89c4ff92SAndroid Build Coastguard Worker #endif
1269*89c4ff92SAndroid Build Coastguard Worker
1270*89c4ff92SAndroid Build Coastguard Worker ProfilingDynamicGuid workloadInferenceID(0);
1271*89c4ff92SAndroid Build Coastguard Worker auto ExecuteQueue = [&timelineUtils, &workloadInferenceID, &inferenceGuid](WorkloadQueue& queue)
1272*89c4ff92SAndroid Build Coastguard Worker {
1273*89c4ff92SAndroid Build Coastguard Worker for (auto& workload : queue)
1274*89c4ff92SAndroid Build Coastguard Worker {
1275*89c4ff92SAndroid Build Coastguard Worker if(timelineUtils)
1276*89c4ff92SAndroid Build Coastguard Worker {
1277*89c4ff92SAndroid Build Coastguard Worker workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(),
1278*89c4ff92SAndroid Build Coastguard Worker inferenceGuid);
1279*89c4ff92SAndroid Build Coastguard Worker }
1280*89c4ff92SAndroid Build Coastguard Worker workload->Execute();
1281*89c4ff92SAndroid Build Coastguard Worker if(timelineUtils)
1282*89c4ff92SAndroid Build Coastguard Worker {
1283*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEndOfLifeEvent(workloadInferenceID);
1284*89c4ff92SAndroid Build Coastguard Worker }
1285*89c4ff92SAndroid Build Coastguard Worker }
1286*89c4ff92SAndroid Build Coastguard Worker };
1287*89c4ff92SAndroid Build Coastguard Worker
1288*89c4ff92SAndroid Build Coastguard Worker ExecuteQueue(m_InputQueue);
1289*89c4ff92SAndroid Build Coastguard Worker ExecuteQueue(m_WorkloadQueue);
1290*89c4ff92SAndroid Build Coastguard Worker ExecuteQueue(m_OutputQueue);
1291*89c4ff92SAndroid Build Coastguard Worker }
1292*89c4ff92SAndroid Build Coastguard Worker catch (const RuntimeException& error)
1293*89c4ff92SAndroid Build Coastguard Worker {
1294*89c4ff92SAndroid Build Coastguard Worker Fail(error);
1295*89c4ff92SAndroid Build Coastguard Worker }
1296*89c4ff92SAndroid Build Coastguard Worker catch (const std::runtime_error& error)
1297*89c4ff92SAndroid Build Coastguard Worker {
1298*89c4ff92SAndroid Build Coastguard Worker Fail(error);
1299*89c4ff92SAndroid Build Coastguard Worker }
1300*89c4ff92SAndroid Build Coastguard Worker
1301*89c4ff92SAndroid Build Coastguard Worker return success;
1302*89c4ff92SAndroid Build Coastguard Worker }
1303*89c4ff92SAndroid Build Coastguard Worker
EnqueueInput(const ConstTensor & inputTensor,ITensorHandle * inputTensorHandle)1304*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::EnqueueInput(const ConstTensor& inputTensor, ITensorHandle* inputTensorHandle)
1305*89c4ff92SAndroid Build Coastguard Worker {
1306*89c4ff92SAndroid Build Coastguard Worker if (m_NetworkProperties.m_ImportEnabled) // Try import the input tensor
1307*89c4ff92SAndroid Build Coastguard Worker {
1308*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = inputTensorHandle->GetImportFlags();
1309*89c4ff92SAndroid Build Coastguard Worker if (CheckFlag(importFlags, m_NetworkProperties.m_InputSource) )
1310*89c4ff92SAndroid Build Coastguard Worker {
1311*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
1312*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(),
1313*89c4ff92SAndroid Build Coastguard Worker inputTensor.GetMemoryArea());
1314*89c4ff92SAndroid Build Coastguard Worker void* mem = tensorHandle->Map(false);
1315*89c4ff92SAndroid Build Coastguard Worker
1316*89c4ff92SAndroid Build Coastguard Worker if (inputTensorHandle->Import(mem, m_NetworkProperties.m_InputSource))
1317*89c4ff92SAndroid Build Coastguard Worker {
1318*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1319*89c4ff92SAndroid Build Coastguard Worker return;
1320*89c4ff92SAndroid Build Coastguard Worker }
1321*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1322*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("EnqueueInput: Memory Import failed");
1323*89c4ff92SAndroid Build Coastguard Worker }
1324*89c4ff92SAndroid Build Coastguard Worker else
1325*89c4ff92SAndroid Build Coastguard Worker {
1326*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("EnqueueInput: Memory Import failed, backend does not support Import");
1327*89c4ff92SAndroid Build Coastguard Worker }
1328*89c4ff92SAndroid Build Coastguard Worker }
1329*89c4ff92SAndroid Build Coastguard Worker else
1330*89c4ff92SAndroid Build Coastguard Worker {
1331*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyInput");
1332*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
1333*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ConstPassthroughTensorHandle>(inputTensor.GetInfo(), inputTensor.GetMemoryArea());
1334*89c4ff92SAndroid Build Coastguard Worker
1335*89c4ff92SAndroid Build Coastguard Worker auto copyFunc = [](void* dst, const void* src, size_t size)
1336*89c4ff92SAndroid Build Coastguard Worker {
1337*89c4ff92SAndroid Build Coastguard Worker memcpy(dst, src, size);
1338*89c4ff92SAndroid Build Coastguard Worker };
1339*89c4ff92SAndroid Build Coastguard Worker
1340*89c4ff92SAndroid Build Coastguard Worker CopyTensorContentsGeneric(tensorHandle.get(), inputTensorHandle, copyFunc);
1341*89c4ff92SAndroid Build Coastguard Worker }
1342*89c4ff92SAndroid Build Coastguard Worker }
1343*89c4ff92SAndroid Build Coastguard Worker
1344*89c4ff92SAndroid Build Coastguard Worker // Note: We can only import the output pointer if all of the following hold true:
1345*89c4ff92SAndroid Build Coastguard Worker // a) The imported pointer is aligned sufficiently
1346*89c4ff92SAndroid Build Coastguard Worker // b) The tensor has zero padding
1347*89c4ff92SAndroid Build Coastguard Worker // c) There is only one connection to the OutputSlot and it is to an OutputLayer.
1348*89c4ff92SAndroid Build Coastguard Worker // d) The output pointer is allocated via malloc. (Other types will be supported in a later release)
1349*89c4ff92SAndroid Build Coastguard Worker // e) m_IsExportEnabled must be set to true
ImportOutputTensor(const Tensor & outputTensor,ITensorHandle * outputTensorHandle)1350*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ImportOutputTensor(const Tensor& outputTensor, ITensorHandle* outputTensorHandle)
1351*89c4ff92SAndroid Build Coastguard Worker {
1352*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT_MSG(outputTensorHandle != nullptr, "Data should have been allocated.");
1353*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags importFlags = outputTensorHandle->GetImportFlags();
1354*89c4ff92SAndroid Build Coastguard Worker if (CheckFlag(importFlags, m_NetworkProperties.m_OutputSource))
1355*89c4ff92SAndroid Build Coastguard Worker {
1356*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
1357*89c4ff92SAndroid Build Coastguard Worker std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(),
1358*89c4ff92SAndroid Build Coastguard Worker outputTensor.GetMemoryArea());
1359*89c4ff92SAndroid Build Coastguard Worker
1360*89c4ff92SAndroid Build Coastguard Worker void* mem = tensorHandle->Map(false);
1361*89c4ff92SAndroid Build Coastguard Worker bool importOk = outputTensorHandle->Import(mem, m_NetworkProperties.m_OutputSource);
1362*89c4ff92SAndroid Build Coastguard Worker tensorHandle->Unmap();
1363*89c4ff92SAndroid Build Coastguard Worker
1364*89c4ff92SAndroid Build Coastguard Worker if (!importOk)
1365*89c4ff92SAndroid Build Coastguard Worker {
1366*89c4ff92SAndroid Build Coastguard Worker throw MemoryExportException("ImportOutputTensor: Memory Export failed");
1367*89c4ff92SAndroid Build Coastguard Worker }
1368*89c4ff92SAndroid Build Coastguard Worker }
1369*89c4ff92SAndroid Build Coastguard Worker else
1370*89c4ff92SAndroid Build Coastguard Worker {
1371*89c4ff92SAndroid Build Coastguard Worker throw MemoryExportException("ImportOutputTensor: Memory Export failed, attempting to export Input Layer");
1372*89c4ff92SAndroid Build Coastguard Worker }
1373*89c4ff92SAndroid Build Coastguard Worker
1374*89c4ff92SAndroid Build Coastguard Worker }
1375*89c4ff92SAndroid Build Coastguard Worker
CopyToOutputTensor(const Tensor & outputTensor,ITensorHandle * outputTensorHandle)1376*89c4ff92SAndroid Build Coastguard Worker void CopyToOutputTensor(const Tensor& outputTensor, ITensorHandle* outputTensorHandle)
1377*89c4ff92SAndroid Build Coastguard Worker {
1378*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyOutput");
1379*89c4ff92SAndroid Build Coastguard Worker auto copyFunc = [](void* dst, const void* src, size_t size)
1380*89c4ff92SAndroid Build Coastguard Worker {
1381*89c4ff92SAndroid Build Coastguard Worker memcpy(dst, src, size);
1382*89c4ff92SAndroid Build Coastguard Worker };
1383*89c4ff92SAndroid Build Coastguard Worker
1384*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> tensorHandle =
1385*89c4ff92SAndroid Build Coastguard Worker std::make_unique<PassthroughTensorHandle>(outputTensor.GetInfo(),
1386*89c4ff92SAndroid Build Coastguard Worker outputTensor.GetMemoryArea());
1387*89c4ff92SAndroid Build Coastguard Worker
1388*89c4ff92SAndroid Build Coastguard Worker CopyTensorContentsGeneric(outputTensorHandle, tensorHandle.get(), copyFunc);
1389*89c4ff92SAndroid Build Coastguard Worker }
1390*89c4ff92SAndroid Build Coastguard Worker
1391*89c4ff92SAndroid Build Coastguard Worker
GetInputTensor(const LayerBindingId layerId,const InputTensors & inputTensors)1392*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor GetInputTensor(const LayerBindingId layerId, const InputTensors& inputTensors)
1393*89c4ff92SAndroid Build Coastguard Worker {
1394*89c4ff92SAndroid Build Coastguard Worker for (auto inputTensorPair : inputTensors)
1395*89c4ff92SAndroid Build Coastguard Worker {
1396*89c4ff92SAndroid Build Coastguard Worker LayerBindingId id = inputTensorPair.first;
1397*89c4ff92SAndroid Build Coastguard Worker if (id == layerId)
1398*89c4ff92SAndroid Build Coastguard Worker {
1399*89c4ff92SAndroid Build Coastguard Worker return inputTensorPair.second;
1400*89c4ff92SAndroid Build Coastguard Worker }
1401*89c4ff92SAndroid Build Coastguard Worker }
1402*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Input does not exist.");
1403*89c4ff92SAndroid Build Coastguard Worker }
1404*89c4ff92SAndroid Build Coastguard Worker
GetOutputTensor(const LayerBindingId layerId,const OutputTensors & outputTensors)1405*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor GetOutputTensor(const LayerBindingId layerId, const OutputTensors& outputTensors)
1406*89c4ff92SAndroid Build Coastguard Worker {
1407*89c4ff92SAndroid Build Coastguard Worker for (auto outputTensorPair : outputTensors)
1408*89c4ff92SAndroid Build Coastguard Worker {
1409*89c4ff92SAndroid Build Coastguard Worker LayerBindingId id = outputTensorPair.first;
1410*89c4ff92SAndroid Build Coastguard Worker if (id == layerId)
1411*89c4ff92SAndroid Build Coastguard Worker {
1412*89c4ff92SAndroid Build Coastguard Worker return outputTensorPair.second;
1413*89c4ff92SAndroid Build Coastguard Worker }
1414*89c4ff92SAndroid Build Coastguard Worker }
1415*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Output does not exist.");
1416*89c4ff92SAndroid Build Coastguard Worker }
1417*89c4ff92SAndroid Build Coastguard Worker
ImportInputs(const InputTensors & inputTensors,MemorySource forceImportMemorySource)1418*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inputTensors,
1419*89c4ff92SAndroid Build Coastguard Worker MemorySource forceImportMemorySource)
1420*89c4ff92SAndroid Build Coastguard Worker {
1421*89c4ff92SAndroid Build Coastguard Worker if (!m_NetworkProperties.m_AsyncEnabled)
1422*89c4ff92SAndroid Build Coastguard Worker {
1423*89c4ff92SAndroid Build Coastguard Worker // Cannot import if import is not enabled and forceImportMemorySource is undefined
1424*89c4ff92SAndroid Build Coastguard Worker if (forceImportMemorySource == MemorySource::Undefined)
1425*89c4ff92SAndroid Build Coastguard Worker {
1426*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportInputs: Memory Import failed, NetworkProperties.m_ImportEnabled");
1427*89c4ff92SAndroid Build Coastguard Worker }
1428*89c4ff92SAndroid Build Coastguard Worker // The number of pre imported tensors should not exceed the number of inputs.
1429*89c4ff92SAndroid Build Coastguard Worker if (inputTensors.size() > m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumInputs())
1430*89c4ff92SAndroid Build Coastguard Worker {
1431*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportInputs: The number of tensors provided exceeds the number of inputs.");
1432*89c4ff92SAndroid Build Coastguard Worker }
1433*89c4ff92SAndroid Build Coastguard Worker
1434*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> importedInputs;
1435*89c4ff92SAndroid Build Coastguard Worker Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1436*89c4ff92SAndroid Build Coastguard Worker unsigned int inputIndex = 0;
1437*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* inputLayer : graph.GetInputLayers())
1438*89c4ff92SAndroid Build Coastguard Worker {
1439*89c4ff92SAndroid Build Coastguard Worker auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
1440*89c4ff92SAndroid Build Coastguard Worker
1441*89c4ff92SAndroid Build Coastguard Worker if (!outputTensorHandle)
1442*89c4ff92SAndroid Build Coastguard Worker {
1443*89c4ff92SAndroid Build Coastguard Worker inputIndex++;
1444*89c4ff92SAndroid Build Coastguard Worker continue;
1445*89c4ff92SAndroid Build Coastguard Worker }
1446*89c4ff92SAndroid Build Coastguard Worker
1447*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = inputLayer->GetBindingId();
1448*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(inputTensors.begin(), inputTensors.end(), [=](const auto& inputTensor)
1449*89c4ff92SAndroid Build Coastguard Worker {
1450*89c4ff92SAndroid Build Coastguard Worker return inputTensor.first == layerBindingId;
1451*89c4ff92SAndroid Build Coastguard Worker });
1452*89c4ff92SAndroid Build Coastguard Worker
1453*89c4ff92SAndroid Build Coastguard Worker if (it == inputTensors.end())
1454*89c4ff92SAndroid Build Coastguard Worker {
1455*89c4ff92SAndroid Build Coastguard Worker inputIndex++;
1456*89c4ff92SAndroid Build Coastguard Worker continue;
1457*89c4ff92SAndroid Build Coastguard Worker }
1458*89c4ff92SAndroid Build Coastguard Worker
1459*89c4ff92SAndroid Build Coastguard Worker const auto& inputTensor = *it;
1460*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> passThroughTensorHandle =
1461*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
1462*89c4ff92SAndroid Build Coastguard Worker inputTensor.second.GetMemoryArea());
1463*89c4ff92SAndroid Build Coastguard Worker
1464*89c4ff92SAndroid Build Coastguard Worker try
1465*89c4ff92SAndroid Build Coastguard Worker {
1466*89c4ff92SAndroid Build Coastguard Worker if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource)
1467*89c4ff92SAndroid Build Coastguard Worker && (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource)))
1468*89c4ff92SAndroid Build Coastguard Worker {
1469*89c4ff92SAndroid Build Coastguard Worker importedInputs.push_back(inputIndex);
1470*89c4ff92SAndroid Build Coastguard Worker }
1471*89c4ff92SAndroid Build Coastguard Worker passThroughTensorHandle->Unmap();
1472*89c4ff92SAndroid Build Coastguard Worker }
1473*89c4ff92SAndroid Build Coastguard Worker catch(const MemoryImportException& exception)
1474*89c4ff92SAndroid Build Coastguard Worker {
1475*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "An error occurred attempting to import input_"
1476*89c4ff92SAndroid Build Coastguard Worker << inputIndex << " : " << exception.what();
1477*89c4ff92SAndroid Build Coastguard Worker passThroughTensorHandle->Unmap();
1478*89c4ff92SAndroid Build Coastguard Worker }
1479*89c4ff92SAndroid Build Coastguard Worker inputIndex++;
1480*89c4ff92SAndroid Build Coastguard Worker }
1481*89c4ff92SAndroid Build Coastguard Worker
1482*89c4ff92SAndroid Build Coastguard Worker return importedInputs;
1483*89c4ff92SAndroid Build Coastguard Worker }
1484*89c4ff92SAndroid Build Coastguard Worker else
1485*89c4ff92SAndroid Build Coastguard Worker {
1486*89c4ff92SAndroid Build Coastguard Worker // Import when the import of network properties is enabled
1487*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> importedInputs;
1488*89c4ff92SAndroid Build Coastguard Worker Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1489*89c4ff92SAndroid Build Coastguard Worker
1490*89c4ff92SAndroid Build Coastguard Worker for (auto inputTensor : inputTensors)
1491*89c4ff92SAndroid Build Coastguard Worker {
1492*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = inputTensor.first;
1493*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(graph.GetInputLayers().begin(), graph.GetInputLayers().end(), [=](auto* layer)
1494*89c4ff92SAndroid Build Coastguard Worker {
1495*89c4ff92SAndroid Build Coastguard Worker return layer->GetBindingId() == layerBindingId;
1496*89c4ff92SAndroid Build Coastguard Worker });
1497*89c4ff92SAndroid Build Coastguard Worker
1498*89c4ff92SAndroid Build Coastguard Worker if (it == graph.GetInputLayers().end())
1499*89c4ff92SAndroid Build Coastguard Worker {
1500*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(fmt::format(
1501*89c4ff92SAndroid Build Coastguard Worker "ImportInputs: Memory Import failed, unknown LayerBindingId: {}", layerBindingId));
1502*89c4ff92SAndroid Build Coastguard Worker }
1503*89c4ff92SAndroid Build Coastguard Worker
1504*89c4ff92SAndroid Build Coastguard Worker const Layer* layer = *it;
1505*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() != LayerType::Input)
1506*89c4ff92SAndroid Build Coastguard Worker {
1507*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("ImportInputs: given layer not an InputLayer");
1508*89c4ff92SAndroid Build Coastguard Worker }
1509*89c4ff92SAndroid Build Coastguard Worker
1510*89c4ff92SAndroid Build Coastguard Worker auto& backend = m_Backends.at(layer->GetBackendId());
1511*89c4ff92SAndroid Build Coastguard Worker if (!HasCapability(BackendOptions::BackendOption{"PreImportIOTensors", true}, backend->GetCapabilities()))
1512*89c4ff92SAndroid Build Coastguard Worker {
1513*89c4ff92SAndroid Build Coastguard Worker std::string er = backend->GetId();
1514*89c4ff92SAndroid Build Coastguard Worker er += " does not have PreImportIOTensors capability";
1515*89c4ff92SAndroid Build Coastguard Worker throw BackendCapabilityException(er);
1516*89c4ff92SAndroid Build Coastguard Worker }
1517*89c4ff92SAndroid Build Coastguard Worker
1518*89c4ff92SAndroid Build Coastguard Worker const OutputSlot& outputSlot = layer->GetOutputSlots()[0];
1519*89c4ff92SAndroid Build Coastguard Worker
1520*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
1521*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1522*89c4ff92SAndroid Build Coastguard Worker
1523*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1524*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(handleFactory);
1525*89c4ff92SAndroid Build Coastguard Worker
1526*89c4ff92SAndroid Build Coastguard Worker ImportedTensorHandlePin importedTensorHandlePin{layerBindingId,
1527*89c4ff92SAndroid Build Coastguard Worker handleFactory->CreateTensorHandle(tensorInfo, false)};
1528*89c4ff92SAndroid Build Coastguard Worker
1529*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
1530*89c4ff92SAndroid Build Coastguard Worker
1531*89c4ff92SAndroid Build Coastguard Worker if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
1532*89c4ff92SAndroid Build Coastguard Worker {
1533*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(
1534*89c4ff92SAndroid Build Coastguard Worker fmt::format("ImportInputs: Memory Import failed, backend: "
1535*89c4ff92SAndroid Build Coastguard Worker "{} does not support importing from source {}"
1536*89c4ff92SAndroid Build Coastguard Worker , factoryId, m_NetworkProperties.m_InputSource));
1537*89c4ff92SAndroid Build Coastguard Worker }
1538*89c4ff92SAndroid Build Coastguard Worker
1539*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> passThroughTensorHandle =
1540*89c4ff92SAndroid Build Coastguard Worker std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
1541*89c4ff92SAndroid Build Coastguard Worker inputTensor.second.GetMemoryArea());
1542*89c4ff92SAndroid Build Coastguard Worker
1543*89c4ff92SAndroid Build Coastguard Worker if (tensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))
1544*89c4ff92SAndroid Build Coastguard Worker {
1545*89c4ff92SAndroid Build Coastguard Worker importedInputs.push_back(m_CurImportedInputId++);
1546*89c4ff92SAndroid Build Coastguard Worker passThroughTensorHandle->Unmap();
1547*89c4ff92SAndroid Build Coastguard Worker }
1548*89c4ff92SAndroid Build Coastguard Worker else
1549*89c4ff92SAndroid Build Coastguard Worker {
1550*89c4ff92SAndroid Build Coastguard Worker passThroughTensorHandle->Unmap();
1551*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportInputs: Memory Import failed");
1552*89c4ff92SAndroid Build Coastguard Worker }
1553*89c4ff92SAndroid Build Coastguard Worker
1554*89c4ff92SAndroid Build Coastguard Worker m_PreImportedInputHandles.push_back(std::move(importedTensorHandlePin));
1555*89c4ff92SAndroid Build Coastguard Worker }
1556*89c4ff92SAndroid Build Coastguard Worker return importedInputs;
1557*89c4ff92SAndroid Build Coastguard Worker }
1558*89c4ff92SAndroid Build Coastguard Worker }
1559*89c4ff92SAndroid Build Coastguard Worker
ImportOutputs(const OutputTensors & outputTensors,MemorySource forceImportMemorySource)1560*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& outputTensors,
1561*89c4ff92SAndroid Build Coastguard Worker MemorySource forceImportMemorySource)
1562*89c4ff92SAndroid Build Coastguard Worker {
1563*89c4ff92SAndroid Build Coastguard Worker if (!m_NetworkProperties.m_AsyncEnabled)
1564*89c4ff92SAndroid Build Coastguard Worker {
1565*89c4ff92SAndroid Build Coastguard Worker // Cannot import if import is not enabled and forceImportMemorySource is undefined
1566*89c4ff92SAndroid Build Coastguard Worker if (forceImportMemorySource == MemorySource::Undefined)
1567*89c4ff92SAndroid Build Coastguard Worker {
1568*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportOutputs: Memory Import failed, NetworkProperties.m_ImportEnabled");
1569*89c4ff92SAndroid Build Coastguard Worker }
1570*89c4ff92SAndroid Build Coastguard Worker // If forceImportMemorySource is defined, try import if memory is aligned
1571*89c4ff92SAndroid Build Coastguard Worker if (outputTensors.size() != m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumOutputs())
1572*89c4ff92SAndroid Build Coastguard Worker {
1573*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportOutputs: Force Import failed, incorrect number of tensors");
1574*89c4ff92SAndroid Build Coastguard Worker }
1575*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> importedOutputs;
1576*89c4ff92SAndroid Build Coastguard Worker Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1577*89c4ff92SAndroid Build Coastguard Worker
1578*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = 0;
1579*89c4ff92SAndroid Build Coastguard Worker for (const BindableLayer* const outputLayer : graph.GetOutputLayers())
1580*89c4ff92SAndroid Build Coastguard Worker {
1581*89c4ff92SAndroid Build Coastguard Worker auto inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
1582*89c4ff92SAndroid Build Coastguard Worker if (!inputTensorHandle)
1583*89c4ff92SAndroid Build Coastguard Worker {
1584*89c4ff92SAndroid Build Coastguard Worker outputIndex++;
1585*89c4ff92SAndroid Build Coastguard Worker continue;
1586*89c4ff92SAndroid Build Coastguard Worker }
1587*89c4ff92SAndroid Build Coastguard Worker
1588*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = outputLayer->GetBindingId();
1589*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(outputTensors.begin(), outputTensors.end(), [=] (const auto& outputTensor)
1590*89c4ff92SAndroid Build Coastguard Worker {
1591*89c4ff92SAndroid Build Coastguard Worker return outputTensor.first == layerBindingId;
1592*89c4ff92SAndroid Build Coastguard Worker });
1593*89c4ff92SAndroid Build Coastguard Worker
1594*89c4ff92SAndroid Build Coastguard Worker if (it == outputTensors.end())
1595*89c4ff92SAndroid Build Coastguard Worker {
1596*89c4ff92SAndroid Build Coastguard Worker outputIndex++;
1597*89c4ff92SAndroid Build Coastguard Worker continue;
1598*89c4ff92SAndroid Build Coastguard Worker }
1599*89c4ff92SAndroid Build Coastguard Worker
1600*89c4ff92SAndroid Build Coastguard Worker const auto outputTensor = *it;
1601*89c4ff92SAndroid Build Coastguard Worker try
1602*89c4ff92SAndroid Build Coastguard Worker {
1603*89c4ff92SAndroid Build Coastguard Worker // Check if the output memory can be imported
1604*89c4ff92SAndroid Build Coastguard Worker if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource)
1605*89c4ff92SAndroid Build Coastguard Worker && inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
1606*89c4ff92SAndroid Build Coastguard Worker {
1607*89c4ff92SAndroid Build Coastguard Worker importedOutputs.push_back(outputIndex);
1608*89c4ff92SAndroid Build Coastguard Worker }
1609*89c4ff92SAndroid Build Coastguard Worker }
1610*89c4ff92SAndroid Build Coastguard Worker catch(const MemoryImportException& exception)
1611*89c4ff92SAndroid Build Coastguard Worker {
1612*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "An error occurred attempting to import output_"
1613*89c4ff92SAndroid Build Coastguard Worker << outputIndex << " : " << exception.what();
1614*89c4ff92SAndroid Build Coastguard Worker }
1615*89c4ff92SAndroid Build Coastguard Worker outputIndex++;
1616*89c4ff92SAndroid Build Coastguard Worker }
1617*89c4ff92SAndroid Build Coastguard Worker return importedOutputs;
1618*89c4ff92SAndroid Build Coastguard Worker }
1619*89c4ff92SAndroid Build Coastguard Worker
1620*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> importedOutputs;
1621*89c4ff92SAndroid Build Coastguard Worker Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
1622*89c4ff92SAndroid Build Coastguard Worker
1623*89c4ff92SAndroid Build Coastguard Worker for (const auto& outputTensor : outputTensors)
1624*89c4ff92SAndroid Build Coastguard Worker {
1625*89c4ff92SAndroid Build Coastguard Worker auto layerBindingId = outputTensor.first;
1626*89c4ff92SAndroid Build Coastguard Worker auto it = std::find_if(graph.GetOutputLayers().begin(), graph.GetOutputLayers().end(), [=](auto* layer)
1627*89c4ff92SAndroid Build Coastguard Worker {
1628*89c4ff92SAndroid Build Coastguard Worker return layer->GetBindingId() == layerBindingId;
1629*89c4ff92SAndroid Build Coastguard Worker });
1630*89c4ff92SAndroid Build Coastguard Worker
1631*89c4ff92SAndroid Build Coastguard Worker if (it == graph.GetOutputLayers().end())
1632*89c4ff92SAndroid Build Coastguard Worker {
1633*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(fmt::format("ImportOutputs: Memory Import failed, unknown LayerBindingId: {}",
1634*89c4ff92SAndroid Build Coastguard Worker layerBindingId));
1635*89c4ff92SAndroid Build Coastguard Worker }
1636*89c4ff92SAndroid Build Coastguard Worker
1637*89c4ff92SAndroid Build Coastguard Worker const Layer* layer = *it;
1638*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() != LayerType::Output)
1639*89c4ff92SAndroid Build Coastguard Worker {
1640*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("ImportOutputs: given layer not an OutputLayer");
1641*89c4ff92SAndroid Build Coastguard Worker }
1642*89c4ff92SAndroid Build Coastguard Worker
1643*89c4ff92SAndroid Build Coastguard Worker auto& backend = m_Backends.at(layer->GetBackendId());
1644*89c4ff92SAndroid Build Coastguard Worker if (!HasCapability(BackendOptions::BackendOption{"PreImportIOTensors", true}, backend->GetCapabilities()))
1645*89c4ff92SAndroid Build Coastguard Worker {
1646*89c4ff92SAndroid Build Coastguard Worker std::string er = backend->GetId();
1647*89c4ff92SAndroid Build Coastguard Worker er += " does not have PreImportIOTensors capability";
1648*89c4ff92SAndroid Build Coastguard Worker throw BackendCapabilityException(er);
1649*89c4ff92SAndroid Build Coastguard Worker }
1650*89c4ff92SAndroid Build Coastguard Worker
1651*89c4ff92SAndroid Build Coastguard Worker const InputSlot& inputSlot = layer->GetInputSlots()[0];
1652*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId();
1653*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo();
1654*89c4ff92SAndroid Build Coastguard Worker
1655*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1656*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(handleFactory);
1657*89c4ff92SAndroid Build Coastguard Worker
1658*89c4ff92SAndroid Build Coastguard Worker ImportedTensorHandlePin importedTensorHandlePin{layerBindingId,
1659*89c4ff92SAndroid Build Coastguard Worker handleFactory->CreateTensorHandle(tensorInfo, false)};
1660*89c4ff92SAndroid Build Coastguard Worker
1661*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = importedTensorHandlePin.m_TensorHandle.get();
1662*89c4ff92SAndroid Build Coastguard Worker
1663*89c4ff92SAndroid Build Coastguard Worker if (!CheckFlag(tensorHandle->GetImportFlags(), forceImportMemorySource))
1664*89c4ff92SAndroid Build Coastguard Worker {
1665*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException(fmt::format("ImportInputs: Memory Import failed, backend: "
1666*89c4ff92SAndroid Build Coastguard Worker "{} does not support importing from source {}"
1667*89c4ff92SAndroid Build Coastguard Worker , factoryId, forceImportMemorySource));
1668*89c4ff92SAndroid Build Coastguard Worker }
1669*89c4ff92SAndroid Build Coastguard Worker
1670*89c4ff92SAndroid Build Coastguard Worker if (tensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
1671*89c4ff92SAndroid Build Coastguard Worker {
1672*89c4ff92SAndroid Build Coastguard Worker importedOutputs.push_back(m_CurImportedOutputId++);
1673*89c4ff92SAndroid Build Coastguard Worker }
1674*89c4ff92SAndroid Build Coastguard Worker else
1675*89c4ff92SAndroid Build Coastguard Worker {
1676*89c4ff92SAndroid Build Coastguard Worker throw MemoryImportException("ImportInputs: Memory Import failed");
1677*89c4ff92SAndroid Build Coastguard Worker }
1678*89c4ff92SAndroid Build Coastguard Worker
1679*89c4ff92SAndroid Build Coastguard Worker m_PreImportedOutputHandles.push_back(std::move(importedTensorHandlePin));
1680*89c4ff92SAndroid Build Coastguard Worker }
1681*89c4ff92SAndroid Build Coastguard Worker
1682*89c4ff92SAndroid Build Coastguard Worker return importedOutputs;
1683*89c4ff92SAndroid Build Coastguard Worker }
1684*89c4ff92SAndroid Build Coastguard Worker
ClearImportedInputs(const std::vector<ImportedInputId> inputIds)1685*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ClearImportedInputs(const std::vector<ImportedInputId> inputIds)
1686*89c4ff92SAndroid Build Coastguard Worker {
1687*89c4ff92SAndroid Build Coastguard Worker for (auto id : inputIds)
1688*89c4ff92SAndroid Build Coastguard Worker {
1689*89c4ff92SAndroid Build Coastguard Worker if (id > m_PreImportedInputHandles.size())
1690*89c4ff92SAndroid Build Coastguard Worker {
1691*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("ClearImportedInputs::Unknown ImportedInputId: {}", id));
1692*89c4ff92SAndroid Build Coastguard Worker }
1693*89c4ff92SAndroid Build Coastguard Worker
1694*89c4ff92SAndroid Build Coastguard Worker auto& importedTensorHandle = m_PreImportedInputHandles[id].m_TensorHandle;
1695*89c4ff92SAndroid Build Coastguard Worker if (!importedTensorHandle)
1696*89c4ff92SAndroid Build Coastguard Worker {
1697*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1698*89c4ff92SAndroid Build Coastguard Worker fmt::format("ClearImportedInputs::ImportedInput with id: {} has already been deleted", id));
1699*89c4ff92SAndroid Build Coastguard Worker }
1700*89c4ff92SAndroid Build Coastguard Worker // Call Unimport then destroy the tensorHandle
1701*89c4ff92SAndroid Build Coastguard Worker importedTensorHandle->Unimport();
1702*89c4ff92SAndroid Build Coastguard Worker importedTensorHandle = {};
1703*89c4ff92SAndroid Build Coastguard Worker }
1704*89c4ff92SAndroid Build Coastguard Worker }
1705*89c4ff92SAndroid Build Coastguard Worker
ClearImportedOutputs(const std::vector<ImportedOutputId> outputIds)1706*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::ClearImportedOutputs(const std::vector<ImportedOutputId> outputIds)
1707*89c4ff92SAndroid Build Coastguard Worker {
1708*89c4ff92SAndroid Build Coastguard Worker for (auto id : outputIds)
1709*89c4ff92SAndroid Build Coastguard Worker {
1710*89c4ff92SAndroid Build Coastguard Worker if (id > m_PreImportedOutputHandles.size())
1711*89c4ff92SAndroid Build Coastguard Worker {
1712*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("ClearImportedOutputs::Unknown ImportedOutputId: {}", id));
1713*89c4ff92SAndroid Build Coastguard Worker }
1714*89c4ff92SAndroid Build Coastguard Worker
1715*89c4ff92SAndroid Build Coastguard Worker auto& importedTensorHandle = m_PreImportedOutputHandles[id].m_TensorHandle;
1716*89c4ff92SAndroid Build Coastguard Worker if (!importedTensorHandle)
1717*89c4ff92SAndroid Build Coastguard Worker {
1718*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(
1719*89c4ff92SAndroid Build Coastguard Worker fmt::format("ClearImportedOutputs::ImportedOutput with id: {} has already been deleted", id));
1720*89c4ff92SAndroid Build Coastguard Worker }
1721*89c4ff92SAndroid Build Coastguard Worker // Call Unimport then destroy the tensorHandle
1722*89c4ff92SAndroid Build Coastguard Worker importedTensorHandle->Unimport();
1723*89c4ff92SAndroid Build Coastguard Worker importedTensorHandle = {};
1724*89c4ff92SAndroid Build Coastguard Worker }
1725*89c4ff92SAndroid Build Coastguard Worker }
1726*89c4ff92SAndroid Build Coastguard Worker
Execute(const InputTensors & inputTensors,const OutputTensors & outputTensors,IWorkingMemHandle & iWorkingMemHandle,std::vector<ImportedInputId> preImportedInputs,std::vector<ImportedOutputId> preImportedOutputs)1727*89c4ff92SAndroid Build Coastguard Worker Status LoadedNetwork::Execute(const InputTensors& inputTensors,
1728*89c4ff92SAndroid Build Coastguard Worker const OutputTensors& outputTensors,
1729*89c4ff92SAndroid Build Coastguard Worker IWorkingMemHandle& iWorkingMemHandle,
1730*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedInputId> preImportedInputs,
1731*89c4ff92SAndroid Build Coastguard Worker std::vector<ImportedOutputId> preImportedOutputs)
1732*89c4ff92SAndroid Build Coastguard Worker {
1733*89c4ff92SAndroid Build Coastguard Worker const Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
1734*89c4ff92SAndroid Build Coastguard Worker
1735*89c4ff92SAndroid Build Coastguard Worker if (inputTensors.size() + preImportedInputs.size() != graph.GetNumInputs())
1736*89c4ff92SAndroid Build Coastguard Worker {
1737*89c4ff92SAndroid Build Coastguard Worker if (preImportedInputs.empty())
1738*89c4ff92SAndroid Build Coastguard Worker {
1739*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("LoadedNetwork::Execute: Number of inputs provided does not match network.");
1740*89c4ff92SAndroid Build Coastguard Worker }
1741*89c4ff92SAndroid Build Coastguard Worker else
1742*89c4ff92SAndroid Build Coastguard Worker {
1743*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("LoadedNetwork::Execute: "
1744*89c4ff92SAndroid Build Coastguard Worker "Number of inputs + preImportedInputs provided does not match network.");
1745*89c4ff92SAndroid Build Coastguard Worker }
1746*89c4ff92SAndroid Build Coastguard Worker }
1747*89c4ff92SAndroid Build Coastguard Worker
1748*89c4ff92SAndroid Build Coastguard Worker if (outputTensors.size() + preImportedOutputs.size() != graph.GetNumOutputs())
1749*89c4ff92SAndroid Build Coastguard Worker {
1750*89c4ff92SAndroid Build Coastguard Worker if (preImportedOutputs.empty())
1751*89c4ff92SAndroid Build Coastguard Worker {
1752*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("LoadedNetwork::Execute: "
1753*89c4ff92SAndroid Build Coastguard Worker "Number of outputs provided does not match network.");
1754*89c4ff92SAndroid Build Coastguard Worker }
1755*89c4ff92SAndroid Build Coastguard Worker else
1756*89c4ff92SAndroid Build Coastguard Worker {
1757*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("LoadedNetwork::Execute: "
1758*89c4ff92SAndroid Build Coastguard Worker "Number of outputs + preImportedOutputs provided does not match network.");
1759*89c4ff92SAndroid Build Coastguard Worker }
1760*89c4ff92SAndroid Build Coastguard Worker }
1761*89c4ff92SAndroid Build Coastguard Worker
1762*89c4ff92SAndroid Build Coastguard Worker WorkingMemHandle& workingMemHandle = dynamic_cast<WorkingMemHandle&>(iWorkingMemHandle);
1763*89c4ff92SAndroid Build Coastguard Worker // Collect all the given LayerBindingIds and check them for duplicates and unknowns.
1764*89c4ff92SAndroid Build Coastguard Worker std::vector<LayerBindingId>& bindingIds = workingMemHandle.GetBindingIdVector();
1765*89c4ff92SAndroid Build Coastguard Worker unsigned int index = 0;
1766*89c4ff92SAndroid Build Coastguard Worker for (auto pair : inputTensors)
1767*89c4ff92SAndroid Build Coastguard Worker {
1768*89c4ff92SAndroid Build Coastguard Worker bindingIds[index++] = pair.first;
1769*89c4ff92SAndroid Build Coastguard Worker }
1770*89c4ff92SAndroid Build Coastguard Worker for (ImportedInputId id : preImportedInputs)
1771*89c4ff92SAndroid Build Coastguard Worker {
1772*89c4ff92SAndroid Build Coastguard Worker bindingIds[index++] = ValidateImportedInputID(id);
1773*89c4ff92SAndroid Build Coastguard Worker }
1774*89c4ff92SAndroid Build Coastguard Worker for (auto pair : outputTensors)
1775*89c4ff92SAndroid Build Coastguard Worker {
1776*89c4ff92SAndroid Build Coastguard Worker bindingIds[index++] = pair.first;
1777*89c4ff92SAndroid Build Coastguard Worker }
1778*89c4ff92SAndroid Build Coastguard Worker for (ImportedOutputId id : preImportedOutputs)
1779*89c4ff92SAndroid Build Coastguard Worker {
1780*89c4ff92SAndroid Build Coastguard Worker bindingIds[index++] = ValidateImportedOutputID(id);
1781*89c4ff92SAndroid Build Coastguard Worker }
1782*89c4ff92SAndroid Build Coastguard Worker
1783*89c4ff92SAndroid Build Coastguard Worker workingMemHandle.ValidateBindingIds();
1784*89c4ff92SAndroid Build Coastguard Worker
1785*89c4ff92SAndroid Build Coastguard Worker auto resetMemHandle = [&]()
1786*89c4ff92SAndroid Build Coastguard Worker {
1787*89c4ff92SAndroid Build Coastguard Worker for (ImportedInputId id: preImportedInputs)
1788*89c4ff92SAndroid Build Coastguard Worker {
1789*89c4ff92SAndroid Build Coastguard Worker const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
1790*89c4ff92SAndroid Build Coastguard Worker
1791*89c4ff92SAndroid Build Coastguard Worker auto inputHandle = workingMemHandle.GetInputHandle(layerBindingId);
1792*89c4ff92SAndroid Build Coastguard Worker auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
1793*89c4ff92SAndroid Build Coastguard Worker for (auto it : inputConnections)
1794*89c4ff92SAndroid Build Coastguard Worker {
1795*89c4ff92SAndroid Build Coastguard Worker *it = inputHandle;
1796*89c4ff92SAndroid Build Coastguard Worker }
1797*89c4ff92SAndroid Build Coastguard Worker }
1798*89c4ff92SAndroid Build Coastguard Worker
1799*89c4ff92SAndroid Build Coastguard Worker for (ImportedOutputId id: preImportedOutputs)
1800*89c4ff92SAndroid Build Coastguard Worker {
1801*89c4ff92SAndroid Build Coastguard Worker const LayerBindingId layerBindingId = m_PreImportedOutputHandles[id].m_LayerBindingId;
1802*89c4ff92SAndroid Build Coastguard Worker
1803*89c4ff92SAndroid Build Coastguard Worker auto outputHandle = workingMemHandle.GetOutputHandle(layerBindingId);
1804*89c4ff92SAndroid Build Coastguard Worker auto outputConnections = workingMemHandle.GetOutputConnection(layerBindingId);
1805*89c4ff92SAndroid Build Coastguard Worker
1806*89c4ff92SAndroid Build Coastguard Worker for (auto it : outputConnections)
1807*89c4ff92SAndroid Build Coastguard Worker {
1808*89c4ff92SAndroid Build Coastguard Worker *it = outputHandle;
1809*89c4ff92SAndroid Build Coastguard Worker }
1810*89c4ff92SAndroid Build Coastguard Worker }
1811*89c4ff92SAndroid Build Coastguard Worker };
1812*89c4ff92SAndroid Build Coastguard Worker
1813*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<TimelineUtilityMethods> timelineUtils =
1814*89c4ff92SAndroid Build Coastguard Worker TimelineUtilityMethods::GetTimelineUtils(*m_ProfilingService);
1815*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid inferenceGuid = m_ProfilingService->GetNextGuid();
1816*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1817*89c4ff92SAndroid Build Coastguard Worker {
1818*89c4ff92SAndroid Build Coastguard Worker // Add inference timeline trace if profiling is enabled.
1819*89c4ff92SAndroid Build Coastguard Worker ProfilingGuid networkGuid = m_OptimizedNetwork->GetGuid();
1820*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateTypedEntity(inferenceGuid,LabelsAndEventClasses::INFERENCE_GUID);
1821*89c4ff92SAndroid Build Coastguard Worker timelineUtils->CreateRelationship(ProfilingRelationshipType::RetentionLink,
1822*89c4ff92SAndroid Build Coastguard Worker networkGuid,
1823*89c4ff92SAndroid Build Coastguard Worker inferenceGuid,
1824*89c4ff92SAndroid Build Coastguard Worker LabelsAndEventClasses::EXECUTION_OF_GUID);
1825*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEvent(inferenceGuid,LabelsAndEventClasses::ARMNN_PROFILING_SOL_EVENT_CLASS);
1826*89c4ff92SAndroid Build Coastguard Worker }
1827*89c4ff92SAndroid Build Coastguard Worker
1828*89c4ff92SAndroid Build Coastguard Worker bool executionSucceeded = true;
1829*89c4ff92SAndroid Build Coastguard Worker
1830*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1831*89c4ff92SAndroid Build Coastguard Worker {
1832*89c4ff92SAndroid Build Coastguard Worker // Add end of life of the inference timeline if profiling is enabled.
1833*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEvent(inferenceGuid,LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
1834*89c4ff92SAndroid Build Coastguard Worker timelineUtils->Commit();
1835*89c4ff92SAndroid Build Coastguard Worker }
1836*89c4ff92SAndroid Build Coastguard Worker
1837*89c4ff92SAndroid Build Coastguard Worker if (!workingMemHandle.IsAllocated())
1838*89c4ff92SAndroid Build Coastguard Worker {
1839*89c4ff92SAndroid Build Coastguard Worker workingMemHandle.Allocate();
1840*89c4ff92SAndroid Build Coastguard Worker }
1841*89c4ff92SAndroid Build Coastguard Worker
1842*89c4ff92SAndroid Build Coastguard Worker {
1843*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareInputs");
1844*89c4ff92SAndroid Build Coastguard Worker for (auto pair : inputTensors)
1845*89c4ff92SAndroid Build Coastguard Worker {
1846*89c4ff92SAndroid Build Coastguard Worker EnqueueInput(pair.second, workingMemHandle.GetInputHandle(pair.first));
1847*89c4ff92SAndroid Build Coastguard Worker }
1848*89c4ff92SAndroid Build Coastguard Worker
1849*89c4ff92SAndroid Build Coastguard Worker // Swap in the pre-imported inputs if any
1850*89c4ff92SAndroid Build Coastguard Worker for (ImportedInputId id : preImportedInputs)
1851*89c4ff92SAndroid Build Coastguard Worker {
1852*89c4ff92SAndroid Build Coastguard Worker const ImportedTensorHandlePin& importedInputPin = m_PreImportedInputHandles[id];
1853*89c4ff92SAndroid Build Coastguard Worker const LayerBindingId layerBindingId = m_PreImportedInputHandles[id].m_LayerBindingId;
1854*89c4ff92SAndroid Build Coastguard Worker const auto& preimportedHandle = importedInputPin.m_TensorHandle;
1855*89c4ff92SAndroid Build Coastguard Worker
1856*89c4ff92SAndroid Build Coastguard Worker auto inputConnections = workingMemHandle.GetInputConnections(layerBindingId);
1857*89c4ff92SAndroid Build Coastguard Worker for (auto it : inputConnections)
1858*89c4ff92SAndroid Build Coastguard Worker {
1859*89c4ff92SAndroid Build Coastguard Worker *it = preimportedHandle.get();
1860*89c4ff92SAndroid Build Coastguard Worker }
1861*89c4ff92SAndroid Build Coastguard Worker }
1862*89c4ff92SAndroid Build Coastguard Worker }
1863*89c4ff92SAndroid Build Coastguard Worker {
1864*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs");
1865*89c4ff92SAndroid Build Coastguard Worker if (m_NetworkProperties.m_ExportEnabled)
1866*89c4ff92SAndroid Build Coastguard Worker {
1867*89c4ff92SAndroid Build Coastguard Worker for (auto pair: outputTensors)
1868*89c4ff92SAndroid Build Coastguard Worker {
1869*89c4ff92SAndroid Build Coastguard Worker ImportOutputTensor(pair.second, workingMemHandle.GetOutputHandle(pair.first));
1870*89c4ff92SAndroid Build Coastguard Worker }
1871*89c4ff92SAndroid Build Coastguard Worker }
1872*89c4ff92SAndroid Build Coastguard Worker
1873*89c4ff92SAndroid Build Coastguard Worker for (ImportedOutputId id : preImportedOutputs)
1874*89c4ff92SAndroid Build Coastguard Worker {
1875*89c4ff92SAndroid Build Coastguard Worker const ImportedTensorHandlePin& importedOutputPin = m_PreImportedOutputHandles[id];
1876*89c4ff92SAndroid Build Coastguard Worker const LayerBindingId layerBindingId = m_PreImportedOutputHandles[id].m_LayerBindingId;
1877*89c4ff92SAndroid Build Coastguard Worker const auto& preimportedHandle = importedOutputPin.m_TensorHandle;
1878*89c4ff92SAndroid Build Coastguard Worker
1879*89c4ff92SAndroid Build Coastguard Worker auto outputConnections = workingMemHandle.GetOutputConnection(layerBindingId);
1880*89c4ff92SAndroid Build Coastguard Worker for (auto it : outputConnections)
1881*89c4ff92SAndroid Build Coastguard Worker {
1882*89c4ff92SAndroid Build Coastguard Worker *it = preimportedHandle.get();
1883*89c4ff92SAndroid Build Coastguard Worker }
1884*89c4ff92SAndroid Build Coastguard Worker }
1885*89c4ff92SAndroid Build Coastguard Worker }
1886*89c4ff92SAndroid Build Coastguard Worker
1887*89c4ff92SAndroid Build Coastguard Worker auto Fail = [&](const std::exception& error)
1888*89c4ff92SAndroid Build Coastguard Worker {
1889*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(error) << "An error occurred attempting to execute a workload: " << error.what();
1890*89c4ff92SAndroid Build Coastguard Worker executionSucceeded = false;
1891*89c4ff92SAndroid Build Coastguard Worker };
1892*89c4ff92SAndroid Build Coastguard Worker ProfilingDynamicGuid workloadInferenceID(0);
1893*89c4ff92SAndroid Build Coastguard Worker
1894*89c4ff92SAndroid Build Coastguard Worker try
1895*89c4ff92SAndroid Build Coastguard Worker {
1896*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_WorkloadQueue.size(); ++i)
1897*89c4ff92SAndroid Build Coastguard Worker {
1898*89c4ff92SAndroid Build Coastguard Worker auto& workload = m_WorkloadQueue[i];
1899*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1900*89c4ff92SAndroid Build Coastguard Worker {
1901*89c4ff92SAndroid Build Coastguard Worker workloadInferenceID = timelineUtils->RecordWorkloadInferenceAndStartOfLifeEvent(workload->GetGuid(),
1902*89c4ff92SAndroid Build Coastguard Worker inferenceGuid);
1903*89c4ff92SAndroid Build Coastguard Worker }
1904*89c4ff92SAndroid Build Coastguard Worker
1905*89c4ff92SAndroid Build Coastguard Worker workload->ExecuteAsync(workingMemHandle.GetExecutionDataAt(i).second);
1906*89c4ff92SAndroid Build Coastguard Worker
1907*89c4ff92SAndroid Build Coastguard Worker if (timelineUtils)
1908*89c4ff92SAndroid Build Coastguard Worker {
1909*89c4ff92SAndroid Build Coastguard Worker timelineUtils->RecordEndOfLifeEvent(workloadInferenceID);
1910*89c4ff92SAndroid Build Coastguard Worker }
1911*89c4ff92SAndroid Build Coastguard Worker }
1912*89c4ff92SAndroid Build Coastguard Worker }
1913*89c4ff92SAndroid Build Coastguard Worker catch (const RuntimeException& error)
1914*89c4ff92SAndroid Build Coastguard Worker {
1915*89c4ff92SAndroid Build Coastguard Worker resetMemHandle();
1916*89c4ff92SAndroid Build Coastguard Worker Fail(error);
1917*89c4ff92SAndroid Build Coastguard Worker }
1918*89c4ff92SAndroid Build Coastguard Worker catch (const std::runtime_error& error)
1919*89c4ff92SAndroid Build Coastguard Worker {
1920*89c4ff92SAndroid Build Coastguard Worker resetMemHandle();
1921*89c4ff92SAndroid Build Coastguard Worker Fail(error);
1922*89c4ff92SAndroid Build Coastguard Worker }
1923*89c4ff92SAndroid Build Coastguard Worker catch (...)
1924*89c4ff92SAndroid Build Coastguard Worker {
1925*89c4ff92SAndroid Build Coastguard Worker resetMemHandle();
1926*89c4ff92SAndroid Build Coastguard Worker throw;
1927*89c4ff92SAndroid Build Coastguard Worker }
1928*89c4ff92SAndroid Build Coastguard Worker
1929*89c4ff92SAndroid Build Coastguard Worker if (!m_NetworkProperties.m_ExportEnabled)
1930*89c4ff92SAndroid Build Coastguard Worker {
1931*89c4ff92SAndroid Build Coastguard Worker for (auto pair: outputTensors)
1932*89c4ff92SAndroid Build Coastguard Worker {
1933*89c4ff92SAndroid Build Coastguard Worker CopyToOutputTensor(pair.second, workingMemHandle.GetOutputHandle(pair.first));
1934*89c4ff92SAndroid Build Coastguard Worker }
1935*89c4ff92SAndroid Build Coastguard Worker }
1936*89c4ff92SAndroid Build Coastguard Worker else
1937*89c4ff92SAndroid Build Coastguard Worker {
1938*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "SyncMemGeneric_Execute");
1939*89c4ff92SAndroid Build Coastguard Worker workingMemHandle.MemSyncOutputs();
1940*89c4ff92SAndroid Build Coastguard Worker }
1941*89c4ff92SAndroid Build Coastguard Worker
1942*89c4ff92SAndroid Build Coastguard Worker resetMemHandle();
1943*89c4ff92SAndroid Build Coastguard Worker
1944*89c4ff92SAndroid Build Coastguard Worker return executionSucceeded ? Status::Success : Status::Failure;
1945*89c4ff92SAndroid Build Coastguard Worker }
1946*89c4ff92SAndroid Build Coastguard Worker
1947*89c4ff92SAndroid Build Coastguard Worker /// Create a new unique WorkingMemHandle object. Create multiple handles if you wish to have
1948*89c4ff92SAndroid Build Coastguard Worker /// overlapped Execution by calling this function from different threads.
CreateWorkingMemHandle(NetworkId networkId)1949*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(NetworkId networkId)
1950*89c4ff92SAndroid Build Coastguard Worker {
1951*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph();
1952*89c4ff92SAndroid Build Coastguard Worker
1953*89c4ff92SAndroid Build Coastguard Worker // Tensors that will need to be allocated internally within armnn
1954*89c4ff92SAndroid Build Coastguard Worker std::vector<std::unique_ptr<ITensorHandle>> managedTensorHandles;
1955*89c4ff92SAndroid Build Coastguard Worker // Tensors that will be allocated externally by the user
1956*89c4ff92SAndroid Build Coastguard Worker std::vector<std::unique_ptr<ITensorHandle>> unmanagedTensorHandles;
1957*89c4ff92SAndroid Build Coastguard Worker
1958*89c4ff92SAndroid Build Coastguard Worker std::vector<WorkingMemDescriptor> workingMemDescriptors;
1959*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<BackendId, ExecutionData>> executionDataVec;
1960*89c4ff92SAndroid Build Coastguard Worker
1961*89c4ff92SAndroid Build Coastguard Worker auto GetTensorHandle = [&](Layer* layer, const OutputSlot& outputSlot)
1962*89c4ff92SAndroid Build Coastguard Worker {
1963*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
1964*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1965*89c4ff92SAndroid Build Coastguard Worker
1966*89c4ff92SAndroid Build Coastguard Worker if (factoryId == ITensorHandleFactory::LegacyFactoryId)
1967*89c4ff92SAndroid Build Coastguard Worker {
1968*89c4ff92SAndroid Build Coastguard Worker BackendId id = layer->GetBackendId();
1969*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
1970*89c4ff92SAndroid Build Coastguard Worker return m_WorkloadFactories.at(id)->CreateTensorHandle(tensorInfo, false);
1971*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
1972*89c4ff92SAndroid Build Coastguard Worker }
1973*89c4ff92SAndroid Build Coastguard Worker else
1974*89c4ff92SAndroid Build Coastguard Worker {
1975*89c4ff92SAndroid Build Coastguard Worker ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
1976*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(handleFactory);
1977*89c4ff92SAndroid Build Coastguard Worker return handleFactory->CreateTensorHandle(tensorInfo, false);
1978*89c4ff92SAndroid Build Coastguard Worker }
1979*89c4ff92SAndroid Build Coastguard Worker };
1980*89c4ff92SAndroid Build Coastguard Worker
1981*89c4ff92SAndroid Build Coastguard Worker struct HandleInfo
1982*89c4ff92SAndroid Build Coastguard Worker {
1983*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* m_TensorHandle;
1984*89c4ff92SAndroid Build Coastguard Worker
1985*89c4ff92SAndroid Build Coastguard Worker bool m_IsInputLayerHandle = false;
1986*89c4ff92SAndroid Build Coastguard Worker bool m_IsOutputLayerHandle = false;
1987*89c4ff92SAndroid Build Coastguard Worker
1988*89c4ff92SAndroid Build Coastguard Worker WorkingMemHandle::InputMemDescriptorCoords m_InputMemDescriptorCoords;
1989*89c4ff92SAndroid Build Coastguard Worker WorkingMemHandle::OutputMemDescriptorCoords m_OutputMemDescriptorCoords;
1990*89c4ff92SAndroid Build Coastguard Worker };
1991*89c4ff92SAndroid Build Coastguard Worker
1992*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<const OutputSlot*, HandleInfo> outputToHandleInfoMap;
1993*89c4ff92SAndroid Build Coastguard Worker
1994*89c4ff92SAndroid Build Coastguard Worker unsigned int layerIndex = 0;
1995*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
1996*89c4ff92SAndroid Build Coastguard Worker {
1997*89c4ff92SAndroid Build Coastguard Worker // Constant layers execution and management is handled during loaded network construction
1998*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == LayerType::Constant)
1999*89c4ff92SAndroid Build Coastguard Worker {
2000*89c4ff92SAndroid Build Coastguard Worker continue;
2001*89c4ff92SAndroid Build Coastguard Worker }
2002*89c4ff92SAndroid Build Coastguard Worker
2003*89c4ff92SAndroid Build Coastguard Worker WorkingMemDescriptor workingMemDescriptor;
2004*89c4ff92SAndroid Build Coastguard Worker
2005*89c4ff92SAndroid Build Coastguard Worker bool isMemoryManaged = true;
2006*89c4ff92SAndroid Build Coastguard Worker bool isInputLayer = false;
2007*89c4ff92SAndroid Build Coastguard Worker bool isOutputLayer = false;
2008*89c4ff92SAndroid Build Coastguard Worker bool isConnectedToOutputLayer = false;
2009*89c4ff92SAndroid Build Coastguard Worker
2010*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == LayerType::Input || layer->GetType() == LayerType::MemImport)
2011*89c4ff92SAndroid Build Coastguard Worker {
2012*89c4ff92SAndroid Build Coastguard Worker // Input layers/workloads will not be executed so the descriptor is not added to workingMemDescriptors
2013*89c4ff92SAndroid Build Coastguard Worker // However we will still need to manage the tensorHandle
2014*89c4ff92SAndroid Build Coastguard Worker isInputLayer = true;
2015*89c4ff92SAndroid Build Coastguard Worker isMemoryManaged = !m_NetworkProperties.m_ImportEnabled;
2016*89c4ff92SAndroid Build Coastguard Worker }
2017*89c4ff92SAndroid Build Coastguard Worker else if (layer->GetType() == LayerType::Output)
2018*89c4ff92SAndroid Build Coastguard Worker {
2019*89c4ff92SAndroid Build Coastguard Worker isOutputLayer = true;
2020*89c4ff92SAndroid Build Coastguard Worker }
2021*89c4ff92SAndroid Build Coastguard Worker
2022*89c4ff92SAndroid Build Coastguard Worker unsigned int slotIndex = 0;
2023*89c4ff92SAndroid Build Coastguard Worker // Create a tensor handle for each output slot of a layer
2024*89c4ff92SAndroid Build Coastguard Worker // Once we create it, we start managing its lifetime
2025*89c4ff92SAndroid Build Coastguard Worker for (auto& slot : layer->GetOutputSlots())
2026*89c4ff92SAndroid Build Coastguard Worker {
2027*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < slot.GetNumConnections(); ++i)
2028*89c4ff92SAndroid Build Coastguard Worker {
2029*89c4ff92SAndroid Build Coastguard Worker if ((slot.GetConnection(i)->GetOwningLayer().GetType() == LayerType::Output))
2030*89c4ff92SAndroid Build Coastguard Worker {
2031*89c4ff92SAndroid Build Coastguard Worker if (!isConnectedToOutputLayer)
2032*89c4ff92SAndroid Build Coastguard Worker {
2033*89c4ff92SAndroid Build Coastguard Worker isConnectedToOutputLayer = true;
2034*89c4ff92SAndroid Build Coastguard Worker // If Export is enabled disable memory management, so we can export, otherwise we do a copy
2035*89c4ff92SAndroid Build Coastguard Worker isMemoryManaged = !m_NetworkProperties.m_ExportEnabled;
2036*89c4ff92SAndroid Build Coastguard Worker }
2037*89c4ff92SAndroid Build Coastguard Worker else
2038*89c4ff92SAndroid Build Coastguard Worker {
2039*89c4ff92SAndroid Build Coastguard Worker // Importing in this case would likely cause unexpected behaviour, so we disallow it.
2040*89c4ff92SAndroid Build Coastguard Worker ARMNN_LOG(warning) <<
2041*89c4ff92SAndroid Build Coastguard Worker fmt::format("Layer name: '{0}' guid: '{1}' has two or more OutputLayers connected to it. "
2042*89c4ff92SAndroid Build Coastguard Worker "This will prevent importing on the connected OutputLayers.",
2043*89c4ff92SAndroid Build Coastguard Worker layer->GetName(), layer->GetGuid());
2044*89c4ff92SAndroid Build Coastguard Worker isMemoryManaged = true;
2045*89c4ff92SAndroid Build Coastguard Worker }
2046*89c4ff92SAndroid Build Coastguard Worker }
2047*89c4ff92SAndroid Build Coastguard Worker }
2048*89c4ff92SAndroid Build Coastguard Worker
2049*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle;
2050*89c4ff92SAndroid Build Coastguard Worker if (isMemoryManaged)
2051*89c4ff92SAndroid Build Coastguard Worker {
2052*89c4ff92SAndroid Build Coastguard Worker managedTensorHandles.emplace_back(GetTensorHandle(layer, slot));
2053*89c4ff92SAndroid Build Coastguard Worker tensorHandle = managedTensorHandles.back().get();
2054*89c4ff92SAndroid Build Coastguard Worker }
2055*89c4ff92SAndroid Build Coastguard Worker else
2056*89c4ff92SAndroid Build Coastguard Worker {
2057*89c4ff92SAndroid Build Coastguard Worker unmanagedTensorHandles.emplace_back(GetTensorHandle(layer, slot));
2058*89c4ff92SAndroid Build Coastguard Worker tensorHandle = unmanagedTensorHandles.back().get();
2059*89c4ff92SAndroid Build Coastguard Worker }
2060*89c4ff92SAndroid Build Coastguard Worker
2061*89c4ff92SAndroid Build Coastguard Worker workingMemDescriptor.m_Outputs.push_back(tensorHandle);
2062*89c4ff92SAndroid Build Coastguard Worker
2063*89c4ff92SAndroid Build Coastguard Worker HandleInfo& handleInfo = outputToHandleInfoMap[&slot];
2064*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_TensorHandle = tensorHandle;
2065*89c4ff92SAndroid Build Coastguard Worker
2066*89c4ff92SAndroid Build Coastguard Worker // Store the coordinates of the current layer's OutputSlot that is connected to the OutputLayer
2067*89c4ff92SAndroid Build Coastguard Worker if (isConnectedToOutputLayer)
2068*89c4ff92SAndroid Build Coastguard Worker {
2069*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_IsOutputLayerHandle = true;
2070*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_OutputSlotCoords = {layerIndex, slotIndex};
2071*89c4ff92SAndroid Build Coastguard Worker }
2072*89c4ff92SAndroid Build Coastguard Worker // Store the LayerBindingId of the InputLayer
2073*89c4ff92SAndroid Build Coastguard Worker if (isInputLayer)
2074*89c4ff92SAndroid Build Coastguard Worker {
2075*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_IsInputLayerHandle = true;
2076*89c4ff92SAndroid Build Coastguard Worker LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2077*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_InputMemDescriptorCoords.m_LayerBindingId = bindingId;
2078*89c4ff92SAndroid Build Coastguard Worker }
2079*89c4ff92SAndroid Build Coastguard Worker slotIndex++;
2080*89c4ff92SAndroid Build Coastguard Worker }
2081*89c4ff92SAndroid Build Coastguard Worker // Loop through the input slots in the same layer and decrement the reference counter associated
2082*89c4ff92SAndroid Build Coastguard Worker // to each tensor handle we encounter.
2083*89c4ff92SAndroid Build Coastguard Worker // Once it reaches zero, the lifetime of the tensor handle has ended, and we mark its memory as available
2084*89c4ff92SAndroid Build Coastguard Worker // so that the next tensor handle with a non overlapping lifetime can share its memory.
2085*89c4ff92SAndroid Build Coastguard Worker for (auto& slot : layer->GetInputSlots())
2086*89c4ff92SAndroid Build Coastguard Worker {
2087*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(slot.GetConnection());
2088*89c4ff92SAndroid Build Coastguard Worker auto outputSlot = slot.GetConnectedOutputSlot();
2089*89c4ff92SAndroid Build Coastguard Worker auto key = outputSlot->GetOwningLayer().GetGuid();
2090*89c4ff92SAndroid Build Coastguard Worker
2091*89c4ff92SAndroid Build Coastguard Worker // Constant layers execution and management is handled during loaded network construction
2092*89c4ff92SAndroid Build Coastguard Worker auto found = m_ConstantTensorHandles.find(key);
2093*89c4ff92SAndroid Build Coastguard Worker if (found != m_ConstantTensorHandles.end())
2094*89c4ff92SAndroid Build Coastguard Worker {
2095*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = found->second;
2096*89c4ff92SAndroid Build Coastguard Worker workingMemDescriptor.m_Inputs.push_back(tensorHandle);
2097*89c4ff92SAndroid Build Coastguard Worker
2098*89c4ff92SAndroid Build Coastguard Worker // Odd case where a constant layer is connected to an output layer
2099*89c4ff92SAndroid Build Coastguard Worker // We will need to create a HandleInfo to track it
2100*89c4ff92SAndroid Build Coastguard Worker if (isOutputLayer)
2101*89c4ff92SAndroid Build Coastguard Worker {
2102*89c4ff92SAndroid Build Coastguard Worker LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2103*89c4ff92SAndroid Build Coastguard Worker
2104*89c4ff92SAndroid Build Coastguard Worker HandleInfo& handleInfo = outputToHandleInfoMap[outputSlot];
2105*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_TensorHandle = tensorHandle;
2106*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_IsOutputLayerHandle = true;
2107*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
2108*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
2109*89c4ff92SAndroid Build Coastguard Worker }
2110*89c4ff92SAndroid Build Coastguard Worker continue;
2111*89c4ff92SAndroid Build Coastguard Worker }
2112*89c4ff92SAndroid Build Coastguard Worker
2113*89c4ff92SAndroid Build Coastguard Worker HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot);
2114*89c4ff92SAndroid Build Coastguard Worker
2115*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle;
2116*89c4ff92SAndroid Build Coastguard Worker workingMemDescriptor.m_Inputs.push_back(inputTensorHandle);
2117*89c4ff92SAndroid Build Coastguard Worker
2118*89c4ff92SAndroid Build Coastguard Worker // Store the LayerBindingId of the OutputLayer
2119*89c4ff92SAndroid Build Coastguard Worker if (isOutputLayer)
2120*89c4ff92SAndroid Build Coastguard Worker {
2121*89c4ff92SAndroid Build Coastguard Worker LayerBindingId bindingId = static_cast<BindableLayer*>(layer)->GetBindingId();
2122*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_LayerBindingIds.push_back(bindingId);
2123*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, 0});
2124*89c4ff92SAndroid Build Coastguard Worker }
2125*89c4ff92SAndroid Build Coastguard Worker // In this case the layer is not an Output Layer but shares its input tensorhandle with an OutputLayer
2126*89c4ff92SAndroid Build Coastguard Worker // It will need to be updated as well, if we swap out the tensorhandle
2127*89c4ff92SAndroid Build Coastguard Worker else if (handleInfo.m_IsOutputLayerHandle)
2128*89c4ff92SAndroid Build Coastguard Worker {
2129*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_OutputMemDescriptorCoords.m_InputSlotCoords.push_back({layerIndex, slot.GetSlotIndex()});
2130*89c4ff92SAndroid Build Coastguard Worker }
2131*89c4ff92SAndroid Build Coastguard Worker
2132*89c4ff92SAndroid Build Coastguard Worker // Store the coordinates of the InputSlots connected to the InputLayer
2133*89c4ff92SAndroid Build Coastguard Worker // There can be more than one InputSlot connected to an InputLayer, so we use a vector
2134*89c4ff92SAndroid Build Coastguard Worker if (handleInfo.m_IsInputLayerHandle)
2135*89c4ff92SAndroid Build Coastguard Worker {
2136*89c4ff92SAndroid Build Coastguard Worker std::pair<LayerGuid, unsigned int> connectionLocation{layerIndex, slot.GetSlotIndex()};
2137*89c4ff92SAndroid Build Coastguard Worker handleInfo.m_InputMemDescriptorCoords.m_InputSlotCoords.emplace_back(connectionLocation);
2138*89c4ff92SAndroid Build Coastguard Worker }
2139*89c4ff92SAndroid Build Coastguard Worker }
2140*89c4ff92SAndroid Build Coastguard Worker
2141*89c4ff92SAndroid Build Coastguard Worker // Input/Output layers/workloads will not be executed, so the descriptor is not added to workingMemDescriptors
2142*89c4ff92SAndroid Build Coastguard Worker // However we will still need to manage the tensorHandle
2143*89c4ff92SAndroid Build Coastguard Worker if (!isInputLayer)
2144*89c4ff92SAndroid Build Coastguard Worker {
2145*89c4ff92SAndroid Build Coastguard Worker // Simply auto initialise ExecutionData here, so it's added only for the layer that require execution.
2146*89c4ff92SAndroid Build Coastguard Worker // The memory and data will be allocated/assigned for the void* in WorkingMemHandle::Allocate.
2147*89c4ff92SAndroid Build Coastguard Worker std::pair<BackendId, ExecutionData> dataPair;
2148*89c4ff92SAndroid Build Coastguard Worker dataPair.first = layer->GetBackendId();
2149*89c4ff92SAndroid Build Coastguard Worker
2150*89c4ff92SAndroid Build Coastguard Worker executionDataVec.push_back(dataPair);
2151*89c4ff92SAndroid Build Coastguard Worker workingMemDescriptors.push_back(workingMemDescriptor);
2152*89c4ff92SAndroid Build Coastguard Worker
2153*89c4ff92SAndroid Build Coastguard Worker layerIndex++;
2154*89c4ff92SAndroid Build Coastguard Worker }
2155*89c4ff92SAndroid Build Coastguard Worker }
2156*89c4ff92SAndroid Build Coastguard Worker
2157*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<std::shared_ptr<TensorMemory>, MemorySource>> tensorMemory;
2158*89c4ff92SAndroid Build Coastguard Worker
2159*89c4ff92SAndroid Build Coastguard Worker auto externalMemoryManager = CreateExternalMemoryManger(tensorMemory);
2160*89c4ff92SAndroid Build Coastguard Worker
2161*89c4ff92SAndroid Build Coastguard Worker // Sort m_TensorMemory, so it's order matches the outputSlot order
2162*89c4ff92SAndroid Build Coastguard Worker std::sort(tensorMemory.begin(), tensorMemory.end(),
2163*89c4ff92SAndroid Build Coastguard Worker [](const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& lhs,
2164*89c4ff92SAndroid Build Coastguard Worker const std::pair<std::shared_ptr<TensorMemory>, MemorySource>& rhs)
2165*89c4ff92SAndroid Build Coastguard Worker {
2166*89c4ff92SAndroid Build Coastguard Worker return lhs.first->m_OutputSlotId < rhs.first->m_OutputSlotId;
2167*89c4ff92SAndroid Build Coastguard Worker });
2168*89c4ff92SAndroid Build Coastguard Worker
2169*89c4ff92SAndroid Build Coastguard Worker std::vector<WorkingMemHandle::InputMemDescriptorCoords> inputConnectionsInfo;
2170*89c4ff92SAndroid Build Coastguard Worker std::vector<WorkingMemHandle::OutputMemDescriptorCoords> outputConnectionsInfo;
2171*89c4ff92SAndroid Build Coastguard Worker
2172*89c4ff92SAndroid Build Coastguard Worker for (const auto& handleInfo: outputToHandleInfoMap)
2173*89c4ff92SAndroid Build Coastguard Worker {
2174*89c4ff92SAndroid Build Coastguard Worker if (handleInfo.second.m_IsOutputLayerHandle)
2175*89c4ff92SAndroid Build Coastguard Worker {
2176*89c4ff92SAndroid Build Coastguard Worker outputConnectionsInfo.emplace_back(handleInfo.second.m_OutputMemDescriptorCoords);
2177*89c4ff92SAndroid Build Coastguard Worker }
2178*89c4ff92SAndroid Build Coastguard Worker
2179*89c4ff92SAndroid Build Coastguard Worker if (handleInfo.second.m_IsInputLayerHandle)
2180*89c4ff92SAndroid Build Coastguard Worker {
2181*89c4ff92SAndroid Build Coastguard Worker inputConnectionsInfo.emplace_back(handleInfo.second.m_InputMemDescriptorCoords);
2182*89c4ff92SAndroid Build Coastguard Worker }
2183*89c4ff92SAndroid Build Coastguard Worker }
2184*89c4ff92SAndroid Build Coastguard Worker
2185*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<WorkingMemHandle>(networkId,
2186*89c4ff92SAndroid Build Coastguard Worker inputConnectionsInfo,
2187*89c4ff92SAndroid Build Coastguard Worker outputConnectionsInfo,
2188*89c4ff92SAndroid Build Coastguard Worker workingMemDescriptors,
2189*89c4ff92SAndroid Build Coastguard Worker std::move(externalMemoryManager),
2190*89c4ff92SAndroid Build Coastguard Worker std::move(tensorMemory),
2191*89c4ff92SAndroid Build Coastguard Worker std::move(managedTensorHandles),
2192*89c4ff92SAndroid Build Coastguard Worker std::move(unmanagedTensorHandles),
2193*89c4ff92SAndroid Build Coastguard Worker executionDataVec,
2194*89c4ff92SAndroid Build Coastguard Worker &m_Backends);
2195*89c4ff92SAndroid Build Coastguard Worker }
2196*89c4ff92SAndroid Build Coastguard Worker
RegisterDebugCallback(const DebugCallbackFunction & func)2197*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func)
2198*89c4ff92SAndroid Build Coastguard Worker {
2199*89c4ff92SAndroid Build Coastguard Worker for (auto&& workloadPtr: m_WorkloadQueue)
2200*89c4ff92SAndroid Build Coastguard Worker {
2201*89c4ff92SAndroid Build Coastguard Worker workloadPtr.get()->RegisterDebugCallback(func);
2202*89c4ff92SAndroid Build Coastguard Worker }
2203*89c4ff92SAndroid Build Coastguard Worker }
2204*89c4ff92SAndroid Build Coastguard Worker
2205*89c4ff92SAndroid Build Coastguard Worker
CreateMemoryProfileAsync()2206*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::CreateMemoryProfileAsync()
2207*89c4ff92SAndroid Build Coastguard Worker {
2208*89c4ff92SAndroid Build Coastguard Worker struct PartialBlock
2209*89c4ff92SAndroid Build Coastguard Worker {
2210*89c4ff92SAndroid Build Coastguard Worker unsigned int m_StartOfLife;
2211*89c4ff92SAndroid Build Coastguard Worker unsigned int m_Lifetime;
2212*89c4ff92SAndroid Build Coastguard Worker
2213*89c4ff92SAndroid Build Coastguard Worker size_t m_MemSize;
2214*89c4ff92SAndroid Build Coastguard Worker unsigned int m_Index;
2215*89c4ff92SAndroid Build Coastguard Worker
2216*89c4ff92SAndroid Build Coastguard Worker BackendId m_BackendId;
2217*89c4ff92SAndroid Build Coastguard Worker };
2218*89c4ff92SAndroid Build Coastguard Worker
2219*89c4ff92SAndroid Build Coastguard Worker auto align = [](size_t numToAlign)
2220*89c4ff92SAndroid Build Coastguard Worker {
2221*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = sizeof(float);
2222*89c4ff92SAndroid Build Coastguard Worker return ((numToAlign + alignment - 1) / alignment) * alignment;
2223*89c4ff92SAndroid Build Coastguard Worker };
2224*89c4ff92SAndroid Build Coastguard Worker
2225*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<const OutputSlot*, PartialBlock> memBlockTrackerMap;
2226*89c4ff92SAndroid Build Coastguard Worker
2227*89c4ff92SAndroid Build Coastguard Worker const bool inputImportingEnabled = m_NetworkProperties.m_InputSource != MemorySource::Undefined;
2228*89c4ff92SAndroid Build Coastguard Worker const bool outputImportingEnabled = m_NetworkProperties.m_OutputSource != MemorySource::Undefined;
2229*89c4ff92SAndroid Build Coastguard Worker
2230*89c4ff92SAndroid Build Coastguard Worker unsigned int timestep = 0;
2231*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = 0;
2232*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
2233*89c4ff92SAndroid Build Coastguard Worker
2234*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
2235*89c4ff92SAndroid Build Coastguard Worker {
2236*89c4ff92SAndroid Build Coastguard Worker const LayerType& layerType = layer->GetType();
2237*89c4ff92SAndroid Build Coastguard Worker // Don't manage memory if importing.
2238*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Input && inputImportingEnabled)
2239*89c4ff92SAndroid Build Coastguard Worker {
2240*89c4ff92SAndroid Build Coastguard Worker continue;
2241*89c4ff92SAndroid Build Coastguard Worker }
2242*89c4ff92SAndroid Build Coastguard Worker // Don't manage memory if importing.
2243*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Output && outputImportingEnabled
2244*89c4ff92SAndroid Build Coastguard Worker && layer->GetInputSlot(0).GetConnectedOutputSlot()->GetNumConnections() == 1)
2245*89c4ff92SAndroid Build Coastguard Worker {
2246*89c4ff92SAndroid Build Coastguard Worker continue;
2247*89c4ff92SAndroid Build Coastguard Worker }
2248*89c4ff92SAndroid Build Coastguard Worker // Because Constant Layer memory can not be shared, the memory must persist for the lifetime of execution,
2249*89c4ff92SAndroid Build Coastguard Worker // management is done separately.
2250*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Constant)
2251*89c4ff92SAndroid Build Coastguard Worker {
2252*89c4ff92SAndroid Build Coastguard Worker continue;
2253*89c4ff92SAndroid Build Coastguard Worker }
2254*89c4ff92SAndroid Build Coastguard Worker
2255*89c4ff92SAndroid Build Coastguard Worker BackendId backendId = layer->GetBackendId();
2256*89c4ff92SAndroid Build Coastguard Worker for (auto& outputSlot : layer->GetOutputSlots())
2257*89c4ff92SAndroid Build Coastguard Worker {
2258*89c4ff92SAndroid Build Coastguard Worker if (!m_SupportsExternallyManagedMemory[backendId])
2259*89c4ff92SAndroid Build Coastguard Worker {
2260*89c4ff92SAndroid Build Coastguard Worker continue;
2261*89c4ff92SAndroid Build Coastguard Worker }
2262*89c4ff92SAndroid Build Coastguard Worker
2263*89c4ff92SAndroid Build Coastguard Worker PartialBlock partialBlock;
2264*89c4ff92SAndroid Build Coastguard Worker
2265*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_StartOfLife = timestep;
2266*89c4ff92SAndroid Build Coastguard Worker
2267*89c4ff92SAndroid Build Coastguard Worker size_t alignedSize = align(outputSlot.GetOutputHandler().GetTensorInfo().GetNumBytes());
2268*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize = alignedSize;
2269*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index = outputIndex++;
2270*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Lifetime = outputSlot.GetNumConnections();
2271*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_BackendId = backendId;
2272*89c4ff92SAndroid Build Coastguard Worker
2273*89c4ff92SAndroid Build Coastguard Worker if (partialBlock.m_Lifetime == 0)
2274*89c4ff92SAndroid Build Coastguard Worker {
2275*89c4ff92SAndroid Build Coastguard Worker m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2276*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_StartOfLife,
2277*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize,
2278*89c4ff92SAndroid Build Coastguard Worker 0,
2279*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index);
2280*89c4ff92SAndroid Build Coastguard Worker }
2281*89c4ff92SAndroid Build Coastguard Worker else
2282*89c4ff92SAndroid Build Coastguard Worker {
2283*89c4ff92SAndroid Build Coastguard Worker memBlockTrackerMap[&outputSlot] = partialBlock;
2284*89c4ff92SAndroid Build Coastguard Worker }
2285*89c4ff92SAndroid Build Coastguard Worker }
2286*89c4ff92SAndroid Build Coastguard Worker
2287*89c4ff92SAndroid Build Coastguard Worker for (auto& inputSlot : layer->GetInputSlots())
2288*89c4ff92SAndroid Build Coastguard Worker {
2289*89c4ff92SAndroid Build Coastguard Worker const Layer& connectedInputLayer = inputSlot.GetConnectedOutputSlot()->GetOwningLayer();
2290*89c4ff92SAndroid Build Coastguard Worker const LayerType& owningLayerType = connectedInputLayer.GetType();
2291*89c4ff92SAndroid Build Coastguard Worker
2292*89c4ff92SAndroid Build Coastguard Worker if (owningLayerType == LayerType::Constant)
2293*89c4ff92SAndroid Build Coastguard Worker {
2294*89c4ff92SAndroid Build Coastguard Worker continue;
2295*89c4ff92SAndroid Build Coastguard Worker }
2296*89c4ff92SAndroid Build Coastguard Worker if (inputImportingEnabled && owningLayerType == LayerType::Input)
2297*89c4ff92SAndroid Build Coastguard Worker {
2298*89c4ff92SAndroid Build Coastguard Worker continue;
2299*89c4ff92SAndroid Build Coastguard Worker }
2300*89c4ff92SAndroid Build Coastguard Worker
2301*89c4ff92SAndroid Build Coastguard Worker auto outputSlot = inputSlot.GetConnectedOutputSlot();
2302*89c4ff92SAndroid Build Coastguard Worker
2303*89c4ff92SAndroid Build Coastguard Worker PartialBlock& partialBlock = memBlockTrackerMap.at(outputSlot);
2304*89c4ff92SAndroid Build Coastguard Worker
2305*89c4ff92SAndroid Build Coastguard Worker auto& lifetime = partialBlock.m_Lifetime;
2306*89c4ff92SAndroid Build Coastguard Worker --lifetime;
2307*89c4ff92SAndroid Build Coastguard Worker
2308*89c4ff92SAndroid Build Coastguard Worker if (lifetime == 0)
2309*89c4ff92SAndroid Build Coastguard Worker {
2310*89c4ff92SAndroid Build Coastguard Worker m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2311*89c4ff92SAndroid Build Coastguard Worker timestep,
2312*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize,
2313*89c4ff92SAndroid Build Coastguard Worker 0,
2314*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index);
2315*89c4ff92SAndroid Build Coastguard Worker }
2316*89c4ff92SAndroid Build Coastguard Worker }
2317*89c4ff92SAndroid Build Coastguard Worker ++timestep;
2318*89c4ff92SAndroid Build Coastguard Worker }
2319*89c4ff92SAndroid Build Coastguard Worker }
2320*89c4ff92SAndroid Build Coastguard Worker
CreateMemoryProfile()2321*89c4ff92SAndroid Build Coastguard Worker void LoadedNetwork::CreateMemoryProfile()
2322*89c4ff92SAndroid Build Coastguard Worker {
2323*89c4ff92SAndroid Build Coastguard Worker // Finds the first TensorHandle ancestor of a SubTensorHandle. If the ITensorHandle provided
2324*89c4ff92SAndroid Build Coastguard Worker // is a TensorHandle, the function just returns it
2325*89c4ff92SAndroid Build Coastguard Worker auto TraceSubTensorHandleAncestry = [](ITensorHandle* const subTensorHandle)
2326*89c4ff92SAndroid Build Coastguard Worker {
2327*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* ancestor = subTensorHandle;
2328*89c4ff92SAndroid Build Coastguard Worker while (ancestor && ancestor->GetParent())
2329*89c4ff92SAndroid Build Coastguard Worker {
2330*89c4ff92SAndroid Build Coastguard Worker ancestor = ancestor->GetParent();
2331*89c4ff92SAndroid Build Coastguard Worker }
2332*89c4ff92SAndroid Build Coastguard Worker return ancestor;
2333*89c4ff92SAndroid Build Coastguard Worker };
2334*89c4ff92SAndroid Build Coastguard Worker
2335*89c4ff92SAndroid Build Coastguard Worker struct PartialBlock
2336*89c4ff92SAndroid Build Coastguard Worker {
2337*89c4ff92SAndroid Build Coastguard Worker unsigned int m_StartOfLife;
2338*89c4ff92SAndroid Build Coastguard Worker unsigned int m_Lifetime;
2339*89c4ff92SAndroid Build Coastguard Worker
2340*89c4ff92SAndroid Build Coastguard Worker size_t m_MemSize;
2341*89c4ff92SAndroid Build Coastguard Worker unsigned int m_Index;
2342*89c4ff92SAndroid Build Coastguard Worker
2343*89c4ff92SAndroid Build Coastguard Worker BackendId m_BackendId;
2344*89c4ff92SAndroid Build Coastguard Worker };
2345*89c4ff92SAndroid Build Coastguard Worker
2346*89c4ff92SAndroid Build Coastguard Worker auto align = [](size_t numToAlign)
2347*89c4ff92SAndroid Build Coastguard Worker {
2348*89c4ff92SAndroid Build Coastguard Worker const size_t alignment = sizeof(float);
2349*89c4ff92SAndroid Build Coastguard Worker return ((numToAlign + alignment - 1) / alignment) * alignment;
2350*89c4ff92SAndroid Build Coastguard Worker };
2351*89c4ff92SAndroid Build Coastguard Worker
2352*89c4ff92SAndroid Build Coastguard Worker std::unordered_map<ITensorHandle*, PartialBlock> memBlockTrackerMap;
2353*89c4ff92SAndroid Build Coastguard Worker
2354*89c4ff92SAndroid Build Coastguard Worker const bool inputImportingEnabled = m_NetworkProperties.m_InputSource != MemorySource::Undefined;
2355*89c4ff92SAndroid Build Coastguard Worker const bool outputImportingEnabled = m_NetworkProperties.m_OutputSource != MemorySource::Undefined;
2356*89c4ff92SAndroid Build Coastguard Worker
2357*89c4ff92SAndroid Build Coastguard Worker unsigned int timestep = 0;
2358*89c4ff92SAndroid Build Coastguard Worker unsigned int outputIndex = 0;
2359*89c4ff92SAndroid Build Coastguard Worker Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
2360*89c4ff92SAndroid Build Coastguard Worker
2361*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : order)
2362*89c4ff92SAndroid Build Coastguard Worker {
2363*89c4ff92SAndroid Build Coastguard Worker const LayerType& layerType = layer->GetType();
2364*89c4ff92SAndroid Build Coastguard Worker // Don't manage memory if importing.
2365*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Input && inputImportingEnabled)
2366*89c4ff92SAndroid Build Coastguard Worker {
2367*89c4ff92SAndroid Build Coastguard Worker continue;
2368*89c4ff92SAndroid Build Coastguard Worker }
2369*89c4ff92SAndroid Build Coastguard Worker // Don't manage memory if importing.
2370*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Output && outputImportingEnabled
2371*89c4ff92SAndroid Build Coastguard Worker && layer->GetInputSlot(0).GetConnectedOutputSlot()->GetNumConnections() == 1)
2372*89c4ff92SAndroid Build Coastguard Worker {
2373*89c4ff92SAndroid Build Coastguard Worker continue;
2374*89c4ff92SAndroid Build Coastguard Worker }
2375*89c4ff92SAndroid Build Coastguard Worker // Because Constant Layer memory can not be shared, the memory must persist for the lifetime of execution,
2376*89c4ff92SAndroid Build Coastguard Worker // management is done separately.
2377*89c4ff92SAndroid Build Coastguard Worker if (layerType == LayerType::Constant)
2378*89c4ff92SAndroid Build Coastguard Worker {
2379*89c4ff92SAndroid Build Coastguard Worker continue;
2380*89c4ff92SAndroid Build Coastguard Worker }
2381*89c4ff92SAndroid Build Coastguard Worker
2382*89c4ff92SAndroid Build Coastguard Worker BackendId backendId = layer->GetBackendId();
2383*89c4ff92SAndroid Build Coastguard Worker for (auto& outputSlot : layer->GetOutputSlots())
2384*89c4ff92SAndroid Build Coastguard Worker {
2385*89c4ff92SAndroid Build Coastguard Worker if (!m_SupportsExternallyManagedMemory[backendId])
2386*89c4ff92SAndroid Build Coastguard Worker {
2387*89c4ff92SAndroid Build Coastguard Worker continue;
2388*89c4ff92SAndroid Build Coastguard Worker }
2389*89c4ff92SAndroid Build Coastguard Worker
2390*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = outputSlot.GetOutputHandler().GetData();
2391*89c4ff92SAndroid Build Coastguard Worker tensorHandle = TraceSubTensorHandleAncestry(tensorHandle);
2392*89c4ff92SAndroid Build Coastguard Worker
2393*89c4ff92SAndroid Build Coastguard Worker if (memBlockTrackerMap.find(tensorHandle) == memBlockTrackerMap.end())
2394*89c4ff92SAndroid Build Coastguard Worker {
2395*89c4ff92SAndroid Build Coastguard Worker PartialBlock partialBlock;
2396*89c4ff92SAndroid Build Coastguard Worker
2397*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_StartOfLife = timestep;
2398*89c4ff92SAndroid Build Coastguard Worker
2399*89c4ff92SAndroid Build Coastguard Worker size_t alignedSize = align(outputSlot.GetOutputHandler().GetTensorInfo().GetNumBytes());
2400*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize = alignedSize;
2401*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index = outputIndex++;
2402*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Lifetime = outputSlot.GetNumConnections();
2403*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_BackendId = backendId;
2404*89c4ff92SAndroid Build Coastguard Worker
2405*89c4ff92SAndroid Build Coastguard Worker if (partialBlock.m_Lifetime == 0)
2406*89c4ff92SAndroid Build Coastguard Worker {
2407*89c4ff92SAndroid Build Coastguard Worker m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2408*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_StartOfLife,
2409*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize,
2410*89c4ff92SAndroid Build Coastguard Worker 0,
2411*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index);
2412*89c4ff92SAndroid Build Coastguard Worker }
2413*89c4ff92SAndroid Build Coastguard Worker else
2414*89c4ff92SAndroid Build Coastguard Worker {
2415*89c4ff92SAndroid Build Coastguard Worker memBlockTrackerMap[tensorHandle] = partialBlock;
2416*89c4ff92SAndroid Build Coastguard Worker }
2417*89c4ff92SAndroid Build Coastguard Worker m_Tensorhandles.push_back(tensorHandle);
2418*89c4ff92SAndroid Build Coastguard Worker
2419*89c4ff92SAndroid Build Coastguard Worker }
2420*89c4ff92SAndroid Build Coastguard Worker else
2421*89c4ff92SAndroid Build Coastguard Worker {
2422*89c4ff92SAndroid Build Coastguard Worker memBlockTrackerMap.at(tensorHandle).m_Lifetime += outputSlot.GetNumConnections();
2423*89c4ff92SAndroid Build Coastguard Worker }
2424*89c4ff92SAndroid Build Coastguard Worker }
2425*89c4ff92SAndroid Build Coastguard Worker
2426*89c4ff92SAndroid Build Coastguard Worker for (auto& inputSlot : layer->GetInputSlots())
2427*89c4ff92SAndroid Build Coastguard Worker {
2428*89c4ff92SAndroid Build Coastguard Worker const Layer& connectedInputLayer = inputSlot.GetConnectedOutputSlot()->GetOwningLayer();
2429*89c4ff92SAndroid Build Coastguard Worker const LayerType& owningLayerType = connectedInputLayer.GetType();
2430*89c4ff92SAndroid Build Coastguard Worker
2431*89c4ff92SAndroid Build Coastguard Worker if (owningLayerType == LayerType::Constant)
2432*89c4ff92SAndroid Build Coastguard Worker {
2433*89c4ff92SAndroid Build Coastguard Worker continue;
2434*89c4ff92SAndroid Build Coastguard Worker }
2435*89c4ff92SAndroid Build Coastguard Worker if (inputImportingEnabled && owningLayerType == LayerType::Input)
2436*89c4ff92SAndroid Build Coastguard Worker {
2437*89c4ff92SAndroid Build Coastguard Worker continue;
2438*89c4ff92SAndroid Build Coastguard Worker }
2439*89c4ff92SAndroid Build Coastguard Worker if (!m_SupportsExternallyManagedMemory[connectedInputLayer.GetBackendId()])
2440*89c4ff92SAndroid Build Coastguard Worker {
2441*89c4ff92SAndroid Build Coastguard Worker continue;
2442*89c4ff92SAndroid Build Coastguard Worker }
2443*89c4ff92SAndroid Build Coastguard Worker
2444*89c4ff92SAndroid Build Coastguard Worker auto outputSlot = inputSlot.GetConnectedOutputSlot();
2445*89c4ff92SAndroid Build Coastguard Worker
2446*89c4ff92SAndroid Build Coastguard Worker ITensorHandle* tensorHandle = outputSlot->GetOutputHandler().GetData();
2447*89c4ff92SAndroid Build Coastguard Worker tensorHandle = TraceSubTensorHandleAncestry(tensorHandle);
2448*89c4ff92SAndroid Build Coastguard Worker
2449*89c4ff92SAndroid Build Coastguard Worker PartialBlock& partialBlock = memBlockTrackerMap.at(tensorHandle);
2450*89c4ff92SAndroid Build Coastguard Worker
2451*89c4ff92SAndroid Build Coastguard Worker auto& lifetime = partialBlock.m_Lifetime;
2452*89c4ff92SAndroid Build Coastguard Worker --lifetime;
2453*89c4ff92SAndroid Build Coastguard Worker
2454*89c4ff92SAndroid Build Coastguard Worker if (lifetime == 0)
2455*89c4ff92SAndroid Build Coastguard Worker {
2456*89c4ff92SAndroid Build Coastguard Worker m_MemBlockMap[partialBlock.m_BackendId].emplace_back(partialBlock.m_StartOfLife,
2457*89c4ff92SAndroid Build Coastguard Worker timestep,
2458*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_MemSize,
2459*89c4ff92SAndroid Build Coastguard Worker 0,
2460*89c4ff92SAndroid Build Coastguard Worker partialBlock.m_Index);
2461*89c4ff92SAndroid Build Coastguard Worker }
2462*89c4ff92SAndroid Build Coastguard Worker }
2463*89c4ff92SAndroid Build Coastguard Worker ++timestep;
2464*89c4ff92SAndroid Build Coastguard Worker }
2465*89c4ff92SAndroid Build Coastguard Worker
2466*89c4ff92SAndroid Build Coastguard Worker }
2467*89c4ff92SAndroid Build Coastguard Worker
CreateExternalMemoryManger(std::vector<std::pair<std::shared_ptr<TensorMemory>,MemorySource>> & tensorMemoryVec)2468*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MemoryManager> LoadedNetwork::CreateExternalMemoryManger(
2469*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<std::shared_ptr<TensorMemory>, MemorySource>>& tensorMemoryVec)
2470*89c4ff92SAndroid Build Coastguard Worker {
2471*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<MemoryManager> memoryManager = std::make_unique<MemoryManager>();
2472*89c4ff92SAndroid Build Coastguard Worker auto allocatorMap = BackendRegistryInstance().GetAllocators();
2473*89c4ff92SAndroid Build Coastguard Worker
2474*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : m_MemBinMap)
2475*89c4ff92SAndroid Build Coastguard Worker {
2476*89c4ff92SAndroid Build Coastguard Worker std::vector<BufferStorage> bufferStorageVec;
2477*89c4ff92SAndroid Build Coastguard Worker
2478*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<ICustomAllocator> backendAllocator;
2479*89c4ff92SAndroid Build Coastguard Worker if (allocatorMap.find(backend.first) != allocatorMap.end())
2480*89c4ff92SAndroid Build Coastguard Worker {
2481*89c4ff92SAndroid Build Coastguard Worker backendAllocator = allocatorMap[backend.first];
2482*89c4ff92SAndroid Build Coastguard Worker }
2483*89c4ff92SAndroid Build Coastguard Worker else
2484*89c4ff92SAndroid Build Coastguard Worker {
2485*89c4ff92SAndroid Build Coastguard Worker backendAllocator = m_Backends[backend.first]->GetDefaultAllocator();
2486*89c4ff92SAndroid Build Coastguard Worker }
2487*89c4ff92SAndroid Build Coastguard Worker
2488*89c4ff92SAndroid Build Coastguard Worker for (auto& memBin : backend.second)
2489*89c4ff92SAndroid Build Coastguard Worker {
2490*89c4ff92SAndroid Build Coastguard Worker BufferStorage bufferStorage;
2491*89c4ff92SAndroid Build Coastguard Worker bufferStorage.m_BufferSize = memBin.m_MemSize;
2492*89c4ff92SAndroid Build Coastguard Worker bufferStorage.m_TensorMemoryVector.reserve(memBin.m_MemBlocks.size());
2493*89c4ff92SAndroid Build Coastguard Worker
2494*89c4ff92SAndroid Build Coastguard Worker for (auto& memBlock : memBin.m_MemBlocks)
2495*89c4ff92SAndroid Build Coastguard Worker {
2496*89c4ff92SAndroid Build Coastguard Worker auto tensorMemory = std::make_shared<TensorMemory>(TensorMemory{memBlock.m_Offset, memBlock.m_Index});
2497*89c4ff92SAndroid Build Coastguard Worker
2498*89c4ff92SAndroid Build Coastguard Worker tensorMemoryVec.emplace_back(tensorMemory, backendAllocator->GetMemorySourceType());
2499*89c4ff92SAndroid Build Coastguard Worker bufferStorage.m_TensorMemoryVector.emplace_back(tensorMemory);
2500*89c4ff92SAndroid Build Coastguard Worker }
2501*89c4ff92SAndroid Build Coastguard Worker
2502*89c4ff92SAndroid Build Coastguard Worker bufferStorageVec.emplace_back(std::move(bufferStorage));
2503*89c4ff92SAndroid Build Coastguard Worker }
2504*89c4ff92SAndroid Build Coastguard Worker
2505*89c4ff92SAndroid Build Coastguard Worker memoryManager->StoreMemToAllocate(bufferStorageVec, backendAllocator, 4);
2506*89c4ff92SAndroid Build Coastguard Worker }
2507*89c4ff92SAndroid Build Coastguard Worker
2508*89c4ff92SAndroid Build Coastguard Worker return memoryManager;
2509*89c4ff92SAndroid Build Coastguard Worker }
2510*89c4ff92SAndroid Build Coastguard Worker
ValidateImportedInputID(ImportedInputId id)2511*89c4ff92SAndroid Build Coastguard Worker LayerBindingId LoadedNetwork::ValidateImportedInputID(ImportedInputId id)
2512*89c4ff92SAndroid Build Coastguard Worker {
2513*89c4ff92SAndroid Build Coastguard Worker try
2514*89c4ff92SAndroid Build Coastguard Worker {
2515*89c4ff92SAndroid Build Coastguard Worker const auto& importedTensorHandlePin = m_PreImportedInputHandles.at(id);
2516*89c4ff92SAndroid Build Coastguard Worker if (!importedTensorHandlePin.m_TensorHandle)
2517*89c4ff92SAndroid Build Coastguard Worker {
2518*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute:"
2519*89c4ff92SAndroid Build Coastguard Worker "PreImportedInput: {} has been deleted", id));
2520*89c4ff92SAndroid Build Coastguard Worker }
2521*89c4ff92SAndroid Build Coastguard Worker return importedTensorHandlePin.m_LayerBindingId;
2522*89c4ff92SAndroid Build Coastguard Worker }
2523*89c4ff92SAndroid Build Coastguard Worker catch (const std::out_of_range&)
2524*89c4ff92SAndroid Build Coastguard Worker {
2525*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: Unknown ImportedInputId: {}", id));
2526*89c4ff92SAndroid Build Coastguard Worker }
2527*89c4ff92SAndroid Build Coastguard Worker }
2528*89c4ff92SAndroid Build Coastguard Worker
ValidateImportedOutputID(ImportedOutputId id)2529*89c4ff92SAndroid Build Coastguard Worker LayerBindingId LoadedNetwork::ValidateImportedOutputID(ImportedOutputId id)
2530*89c4ff92SAndroid Build Coastguard Worker {
2531*89c4ff92SAndroid Build Coastguard Worker try
2532*89c4ff92SAndroid Build Coastguard Worker {
2533*89c4ff92SAndroid Build Coastguard Worker const auto& importedTensorHandlePin = m_PreImportedOutputHandles.at(id);
2534*89c4ff92SAndroid Build Coastguard Worker if (!importedTensorHandlePin.m_TensorHandle)
2535*89c4ff92SAndroid Build Coastguard Worker {
2536*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: "
2537*89c4ff92SAndroid Build Coastguard Worker "PreImportedOutput: {} has been deleted", id));
2538*89c4ff92SAndroid Build Coastguard Worker }
2539*89c4ff92SAndroid Build Coastguard Worker return importedTensorHandlePin.m_LayerBindingId;
2540*89c4ff92SAndroid Build Coastguard Worker }
2541*89c4ff92SAndroid Build Coastguard Worker catch (const std::out_of_range&)
2542*89c4ff92SAndroid Build Coastguard Worker {
2543*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("LoadedNetwork::Execute: Unknown ImportedOutputId: {}", id));
2544*89c4ff92SAndroid Build Coastguard Worker }
2545*89c4ff92SAndroid Build Coastguard Worker }
2546*89c4ff92SAndroid Build Coastguard Worker
2547*89c4ff92SAndroid Build Coastguard Worker }
2548