1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "LayerTestResult.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "TensorCopyUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "TensorHelpers.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "WorkloadTestUtils.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockBackend.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)19*89c4ff92SAndroid Build Coastguard Worker LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
20*89c4ff92SAndroid Build Coastguard Worker armnn::IWorkloadFactory& dstWorkloadFactory,
21*89c4ff92SAndroid Build Coastguard Worker bool withSubtensors)
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
24*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape tensorShape(4, shapeData.data());
25*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo tensorInfo(tensorShape, dataType);
26*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData =
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker 1, 2, 3, 4, 5,
29*89c4ff92SAndroid Build Coastguard Worker 6, 7, 8, 9, 10,
30*89c4ff92SAndroid Build Coastguard Worker 11, 12, 13, 14, 15,
31*89c4ff92SAndroid Build Coastguard Worker 16, 17, 18, 19, 20,
32*89c4ff92SAndroid Build Coastguard Worker 21, 22, 23, 24, 25,
33*89c4ff92SAndroid Build Coastguard Worker 26, 27, 28, 29, 30,
34*89c4ff92SAndroid Build Coastguard Worker };
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker LayerTestResult<T, 4> ret(tensorInfo);
37*89c4ff92SAndroid Build Coastguard Worker ret.m_ExpectedData = inputData;
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker std::vector<T> actualOutput(tensorInfo.GetNumElements());
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
42*89c4ff92SAndroid Build Coastguard Worker auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
43*89c4ff92SAndroid Build Coastguard Worker auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
44*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
45*89c4ff92SAndroid Build Coastguard Worker
46*89c4ff92SAndroid Build Coastguard Worker AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
47*89c4ff92SAndroid Build Coastguard Worker outputTensorHandle->Allocate();
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker armnn::MemCopyQueueDescriptor memCopyQueueDesc;
50*89c4ff92SAndroid Build Coastguard Worker armnn::WorkloadInfo workloadInfo;
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker const unsigned int origin[4] = {};
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
55*89c4ff92SAndroid Build Coastguard Worker auto workloadInput = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
56*89c4ff92SAndroid Build Coastguard Worker ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
57*89c4ff92SAndroid Build Coastguard Worker : std::move(inputTensorHandle);
58*89c4ff92SAndroid Build Coastguard Worker auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
59*89c4ff92SAndroid Build Coastguard Worker ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
60*89c4ff92SAndroid Build Coastguard Worker : std::move(outputTensorHandle);
61*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
64*89c4ff92SAndroid Build Coastguard Worker AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
69*89c4ff92SAndroid Build Coastguard Worker ret.m_ActualData = actualOutput;
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker return ret;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker
/// Per-factory hook used by the generic MemCopyTest<Src, Dst, dataType>(bool) overload.
///
/// The primary template is deliberately empty. Every workload factory type used with that overload needs a
/// specialisation providing a static GetMemoryManager() and a static GetFactory(memoryManager).
template <class FactoryType>
struct MemCopyTestHelper
{
};
77*89c4ff92SAndroid Build Coastguard Worker template <>
78*89c4ff92SAndroid Build Coastguard Worker struct MemCopyTestHelper<armnn::MockWorkloadFactory>
79*89c4ff92SAndroid Build Coastguard Worker {
GetMemoryManager__anon845caf830111::MemCopyTestHelper80*89c4ff92SAndroid Build Coastguard Worker static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker armnn::MockBackend backend;
83*89c4ff92SAndroid Build Coastguard Worker return backend.CreateMemoryManager();
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker static armnn::MockWorkloadFactory
GetFactory__anon845caf830111::MemCopyTestHelper87*89c4ff92SAndroid Build Coastguard Worker GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr)
88*89c4ff92SAndroid Build Coastguard Worker {
89*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(memoryManager);
90*89c4ff92SAndroid Build Coastguard Worker return armnn::MockWorkloadFactory();
91*89c4ff92SAndroid Build Coastguard Worker }
92*89c4ff92SAndroid Build Coastguard Worker };
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker using MockMemCopyTestHelper = MemCopyTestHelper<armnn::MockWorkloadFactory>;
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker template <typename SrcWorkloadFactory,
97*89c4ff92SAndroid Build Coastguard Worker typename DstWorkloadFactory,
98*89c4ff92SAndroid Build Coastguard Worker armnn::DataType dataType,
99*89c4ff92SAndroid Build Coastguard Worker typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)100*89c4ff92SAndroid Build Coastguard Worker LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker
103*89c4ff92SAndroid Build Coastguard Worker armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
104*89c4ff92SAndroid Build Coastguard Worker MemCopyTestHelper<SrcWorkloadFactory>::GetMemoryManager();
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
107*89c4ff92SAndroid Build Coastguard Worker MemCopyTestHelper<DstWorkloadFactory>::GetMemoryManager();
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker SrcWorkloadFactory srcWorkloadFactory = MemCopyTestHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
110*89c4ff92SAndroid Build Coastguard Worker DstWorkloadFactory dstWorkloadFactory = MemCopyTestHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
116