xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/MockTensorHandle.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 #pragma once
6 
7 #include "MockMemoryManager.hpp"
8 
9 #include <armnn/backends/TensorHandle.hpp>
10 #include <armnn/MemorySources.hpp>
11 #include <armnn/Tensor.hpp>
12 #include <armnn/Types.hpp>
13 #include <armnn/backends/ITensorHandle.hpp>
14 #include <memory>
15 
16 namespace armnn
17 {
18 
19 // An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
20 class MockTensorHandle : public ITensorHandle
21 {
22 public:
23     MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager);
24 
25     MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags);
26 
27     ~MockTensorHandle() override;
28 
29     void Manage() override;
30 
31     void Allocate() override;
32 
GetParent() const33     ITensorHandle* GetParent() const override
34     {
35         return nullptr;
36     }
37 
38     const void* Map(bool /* blocking = true */) const override;
39     using ITensorHandle::Map;
40 
Unmap() const41     void Unmap() const override
42     {}
43 
GetStrides() const44     TensorShape GetStrides() const override
45     {
46         return GetUnpaddedTensorStrides(m_TensorInfo);
47     }
48 
GetShape() const49     TensorShape GetShape() const override
50     {
51         return m_TensorInfo.GetShape();
52     }
53 
GetTensorInfo() const54     const TensorInfo& GetTensorInfo() const
55     {
56         return m_TensorInfo;
57     }
58 
GetImportFlags() const59     MemorySourceFlags GetImportFlags() const override
60     {
61         return m_ImportFlags;
62     }
63 
64     bool Import(void* memory, MemorySource source) override;
65     bool CanBeImported(void* memory, MemorySource source) override;
66 
67 private:
68     // Only used for testing
69     void CopyOutTo(void*) const override;
70     void CopyInFrom(const void*) override;
71 
72     void* GetPointer() const;
73 
74     MockTensorHandle(const MockTensorHandle& other) = delete;               // noncopyable
75     MockTensorHandle& operator=(const MockTensorHandle& other) = delete;    //noncopyable
76 
77     TensorInfo m_TensorInfo;
78 
79     std::shared_ptr<MockMemoryManager> m_MemoryManager;
80     MockMemoryManager::Pool* m_Pool;
81     mutable void* m_UnmanagedMemory;
82     MemorySourceFlags m_ImportFlags;
83     bool m_Imported;
84     bool m_IsImportEnabled;
85 };
86 
87 }    // namespace armnn
88