xref: /aosp_15_r20/external/armnn/src/armnn/OutputHandler.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <armnn/backends/ITensorHandle.hpp>
#include <armnn/backends/ITensorHandleFactory.hpp>

#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
19 
20 namespace armnn
21 {
22 
23 class ITensorHandle;
24 class IWorkloadFactory;
25 class OutputSlot;
26 class WorkloadDataCollector;
27 
28 class OutputHandler
29 {
30 public:
31     /// @brief - Sets the TensorInfo used by this output handler.
32     /// @param tensorInfo - TensorInfo for the output.
33     void SetTensorInfo(const TensorInfo& tensorInfo);
34 
35     /// @brief - Creates tensor handles used by the intermediate tensors. Does not allocate memory.
36     /// @param factory - Factory to be used for handler creation.
37     void CreateTensorHandles(const IWorkloadFactory& factory, const bool IsMemoryManaged = true);
38     void CreateTensorHandles(const ITensorHandleFactory& factory, const bool IsMemoryManaged = true);
39 
40     /// @brief - Gets the matching TensorInfo for the output.
41     /// @return - References to the output TensorInfo.
GetTensorInfo() const42     const TensorInfo& GetTensorInfo() const { return m_TensorInfo; }
43 
44     /// @brief - Gets the allocated tensor memory.
45     /// @return - Pointer to the tensor memory.
GetData() const46     ITensorHandle* GetData() const { return m_TensorHandle.get(); }
47 
48     /// Fill the outputs for a given queue descriptor.
49     void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const;
50 
SetData(std::unique_ptr<ITensorHandle> data)51     void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); }
52 
53     void SetAllocatedData();
54 
UseAllocatedData()55     void UseAllocatedData() { m_TensorHandle = m_AllocatedTensorHandle; }
56 
57     /// @brief Returns true if SetTensorInfo() has been called at least once on this.
IsTensorInfoSet() const58     bool IsTensorInfoSet() const { return m_bTensorInfoSet; }
59 private:
60     std::shared_ptr<ITensorHandle> m_TensorHandle;
61     std::shared_ptr<ITensorHandle> m_AllocatedTensorHandle;
62     TensorInfo m_TensorInfo;
63     bool m_bTensorInfoSet = false;
64 };
65 
66 } //namespace armnn
67