1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <armnn/backends/ITensorHandle.hpp> 8 #include <armnn/backends/ITensorHandleFactory.hpp> 9 10 #include <armnn/Descriptors.hpp> 11 #include <armnn/INetwork.hpp> 12 #include <armnn/Tensor.hpp> 13 #include <armnn/Types.hpp> 14 15 #include <memory> 16 #include <set> 17 #include <string> 18 #include <vector> 19 20 namespace armnn 21 { 22 23 class ITensorHandle; 24 class IWorkloadFactory; 25 class OutputSlot; 26 class WorkloadDataCollector; 27 28 class OutputHandler 29 { 30 public: 31 /// @brief - Sets the TensorInfo used by this output handler. 32 /// @param tensorInfo - TensorInfo for the output. 33 void SetTensorInfo(const TensorInfo& tensorInfo); 34 35 /// @brief - Creates tensor handles used by the intermediate tensors. Does not allocate memory. 36 /// @param factory - Factory to be used for handler creation. 37 void CreateTensorHandles(const IWorkloadFactory& factory, const bool IsMemoryManaged = true); 38 void CreateTensorHandles(const ITensorHandleFactory& factory, const bool IsMemoryManaged = true); 39 40 /// @brief - Gets the matching TensorInfo for the output. 41 /// @return - References to the output TensorInfo. GetTensorInfo() const42 const TensorInfo& GetTensorInfo() const { return m_TensorInfo; } 43 44 /// @brief - Gets the allocated tensor memory. 45 /// @return - Pointer to the tensor memory. GetData() const46 ITensorHandle* GetData() const { return m_TensorHandle.get(); } 47 48 /// Fill the outputs for a given queue descriptor. 49 void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const; 50 SetData(std::unique_ptr<ITensorHandle> data)51 void SetData(std::unique_ptr<ITensorHandle> data) { m_TensorHandle = std::move(data); } 52 53 void SetAllocatedData(); 54 UseAllocatedData()55 void UseAllocatedData() { m_TensorHandle = m_AllocatedTensorHandle; } 56 57 /// @brief Returns true if SetTensorInfo() has been called at least once on this. IsTensorInfoSet() const58 bool IsTensorInfoSet() const { return m_bTensorInfoSet; } 59 private: 60 std::shared_ptr<ITensorHandle> m_TensorHandle; 61 std::shared_ptr<ITensorHandle> m_AllocatedTensorHandle; 62 TensorInfo m_TensorInfo; 63 bool m_bTensorInfoSet = false; 64 }; 65 66 } //namespace armnn 67