1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "MockTensorHandleFactory.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockTensorHandle.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker namespace armnn
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker using FactoryId = ITensorHandleFactory::FactoryId;
13*89c4ff92SAndroid Build Coastguard Worker
GetIdStatic()14*89c4ff92SAndroid Build Coastguard Worker const FactoryId& MockTensorHandleFactory::GetIdStatic()
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker static const FactoryId s_Id(MockTensorHandleFactoryId());
17*89c4ff92SAndroid Build Coastguard Worker return s_Id;
18*89c4ff92SAndroid Build Coastguard Worker }
19*89c4ff92SAndroid Build Coastguard Worker
CreateSubTensorHandle(ITensorHandle &,TensorShape const &,unsigned int const *) const20*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateSubTensorHandle(ITensorHandle&,
21*89c4ff92SAndroid Build Coastguard Worker TensorShape const&,
22*89c4ff92SAndroid Build Coastguard Worker unsigned int const*) const
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker return nullptr;
25*89c4ff92SAndroid Build Coastguard Worker }
26*89c4ff92SAndroid Build Coastguard Worker
CreateTensorHandle(const TensorInfo & tensorInfo) const27*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo) const
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout) const32*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
33*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout) const
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(dataLayout);
36*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged) const39*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
40*89c4ff92SAndroid Build Coastguard Worker const bool IsMemoryManaged) const
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker if (IsMemoryManaged)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker else
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_ImportFlags);
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged) const52*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> MockTensorHandleFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
53*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout,
54*89c4ff92SAndroid Build Coastguard Worker const bool IsMemoryManaged) const
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(dataLayout);
57*89c4ff92SAndroid Build Coastguard Worker if (IsMemoryManaged)
58*89c4ff92SAndroid Build Coastguard Worker {
59*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker else
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_ImportFlags);
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker
GetId() const67*89c4ff92SAndroid Build Coastguard Worker const FactoryId& MockTensorHandleFactory::GetId() const
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker return GetIdStatic();
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker
SupportsSubTensors() const72*89c4ff92SAndroid Build Coastguard Worker bool MockTensorHandleFactory::SupportsSubTensors() const
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker return false;
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker
GetExportFlags() const77*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags MockTensorHandleFactory::GetExportFlags() const
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker return m_ExportFlags;
80*89c4ff92SAndroid Build Coastguard Worker }
81*89c4ff92SAndroid Build Coastguard Worker
GetImportFlags() const82*89c4ff92SAndroid Build Coastguard Worker MemorySourceFlags MockTensorHandleFactory::GetImportFlags() const
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker return m_ImportFlags;
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker
87*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn