1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include <aclCommon/BaseMemoryManager.hpp> 8 #include <armnn/MemorySources.hpp> 9 #include <armnn/backends/IMemoryManager.hpp> 10 #include <armnn/backends/ITensorHandleFactory.hpp> 11 12 namespace armnn 13 { 14 ClTensorHandleFactoryId()15constexpr const char* ClTensorHandleFactoryId() 16 { 17 return "Arm/Cl/TensorHandleFactory"; 18 } 19 20 class ClTensorHandleFactory : public ITensorHandleFactory 21 { 22 public: 23 static const FactoryId m_Id; 24 ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr)25 ClTensorHandleFactory(std::shared_ptr<ClMemoryManager> mgr) 26 : m_MemoryManager(mgr) 27 {} 28 29 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 30 const TensorShape& subTensorShape, 31 const unsigned int* subTensorOrigin) const override; 32 33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; 34 35 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 36 DataLayout dataLayout) const override; 37 38 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 39 const bool IsMemoryManaged) const override; 40 41 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 42 DataLayout dataLayout, 43 const bool IsMemoryManaged) const override; 44 45 static const FactoryId& GetIdStatic(); 46 47 const FactoryId& GetId() const override; 48 49 bool SupportsSubTensors() const override; 50 51 MemorySourceFlags GetExportFlags() const override; 52 53 MemorySourceFlags GetImportFlags() const override; 54 55 private: 56 mutable std::shared_ptr<ClMemoryManager> m_MemoryManager; 57 }; 58 59 } // namespace armnn 60