xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/MockTensorHandleFactory.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "MockTensorHandleFactory.hpp"
7 #include <armnnTestUtils/MockTensorHandle.hpp>
8 
9 namespace armnn
10 {
11 
12 using FactoryId = ITensorHandleFactory::FactoryId;
13 
GetIdStatic()14 const FactoryId& MockTensorHandleFactory::GetIdStatic()
15 {
16     static const FactoryId s_Id(MockTensorHandleFactoryId());
17     return s_Id;
18 }
19 
CreateSubTensorHandle(ITensorHandle &,TensorShape const &,unsigned int const *) const20 std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateSubTensorHandle(ITensorHandle&,
21                                                                               TensorShape const&,
22                                                                               unsigned int const*) const
23 {
24     return nullptr;
25 }
26 
CreateTensorHandle(const TensorInfo & tensorInfo) const27 std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
28 {
29     return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
30 }
31 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const32 std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
33                                                                            DataLayout dataLayout) const
34 {
35     IgnoreUnused(dataLayout);
36     return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
37 }
38 
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const39 std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
40                                                                            const bool IsMemoryManaged) const
41 {
42     if (IsMemoryManaged)
43     {
44         return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
45     }
46     else
47     {
48         return std::make_unique<MockTensorHandle>(tensorInfo, m_ImportFlags);
49     }
50 }
51 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const52 std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
53                                                                            DataLayout dataLayout,
54                                                                            const bool IsMemoryManaged) const
55 {
56     IgnoreUnused(dataLayout);
57     if (IsMemoryManaged)
58     {
59         return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
60     }
61     else
62     {
63         return std::make_unique<MockTensorHandle>(tensorInfo, m_ImportFlags);
64     }
65 }
66 
GetId() const67 const FactoryId& MockTensorHandleFactory::GetId() const
68 {
69     return GetIdStatic();
70 }
71 
SupportsSubTensors() const72 bool MockTensorHandleFactory::SupportsSubTensors() const
73 {
74     return false;
75 }
76 
GetExportFlags() const77 MemorySourceFlags MockTensorHandleFactory::GetExportFlags() const
78 {
79     return m_ExportFlags;
80 }
81 
GetImportFlags() const82 MemorySourceFlags MockTensorHandleFactory::GetImportFlags() const
83 {
84     return m_ImportFlags;
85 }
86 
87 }    // namespace armnn