1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "MockMemoryManager.hpp" 8 9 #include <armnn/backends/TensorHandle.hpp> 10 #include <armnn/MemorySources.hpp> 11 #include <armnn/Tensor.hpp> 12 #include <armnn/Types.hpp> 13 #include <armnn/backends/ITensorHandle.hpp> 14 #include <memory> 15 16 namespace armnn 17 { 18 19 // An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour 20 class MockTensorHandle : public ITensorHandle 21 { 22 public: 23 MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager); 24 25 MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags); 26 27 ~MockTensorHandle() override; 28 29 void Manage() override; 30 31 void Allocate() override; 32 GetParent() const33 ITensorHandle* GetParent() const override 34 { 35 return nullptr; 36 } 37 38 const void* Map(bool /* blocking = true */) const override; 39 using ITensorHandle::Map; 40 Unmap() const41 void Unmap() const override 42 {} 43 GetStrides() const44 TensorShape GetStrides() const override 45 { 46 return GetUnpaddedTensorStrides(m_TensorInfo); 47 } 48 GetShape() const49 TensorShape GetShape() const override 50 { 51 return m_TensorInfo.GetShape(); 52 } 53 GetTensorInfo() const54 const TensorInfo& GetTensorInfo() const 55 { 56 return m_TensorInfo; 57 } 58 GetImportFlags() const59 MemorySourceFlags GetImportFlags() const override 60 { 61 return m_ImportFlags; 62 } 63 64 bool Import(void* memory, MemorySource source) override; 65 bool CanBeImported(void* memory, MemorySource source) override; 66 67 private: 68 // Only used for testing 69 void CopyOutTo(void*) const override; 70 void CopyInFrom(const void*) override; 71 72 void* GetPointer() const; 73 74 MockTensorHandle(const MockTensorHandle& other) = delete; // noncopyable 75 MockTensorHandle& operator=(const MockTensorHandle& other) = delete; //noncopyable 76 77 TensorInfo m_TensorInfo; 78 79 std::shared_ptr<MockMemoryManager> m_MemoryManager; 80 MockMemoryManager::Pool* m_Pool; 81 mutable void* m_UnmanagedMemory; 82 MemorySourceFlags m_ImportFlags; 83 bool m_Imported; 84 bool m_IsImportEnabled; 85 }; 86 87 } // namespace armnn 88