xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/MemCopyTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
#include "LayerTestResult.hpp"
#include "TensorCopyUtils.hpp"
#include "TensorHelpers.hpp"
#include "WorkloadTestUtils.hpp"
#include <ResolveType.hpp>
#include <armnn/backends/IBackendInternal.hpp>
#include <armnnTestUtils/MockBackend.hpp>

#include <array>
#include <stdexcept>
#include <utility>
#include <vector>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType dataType, typename T = armnn::ResolveType<dataType>>
MemCopyTest(armnn::IWorkloadFactory & srcWorkloadFactory,armnn::IWorkloadFactory & dstWorkloadFactory,bool withSubtensors)19*89c4ff92SAndroid Build Coastguard Worker LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory,
20*89c4ff92SAndroid Build Coastguard Worker                                   armnn::IWorkloadFactory& dstWorkloadFactory,
21*89c4ff92SAndroid Build Coastguard Worker                                   bool withSubtensors)
22*89c4ff92SAndroid Build Coastguard Worker {
23*89c4ff92SAndroid Build Coastguard Worker     const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } };
24*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape tensorShape(4, shapeData.data());
25*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo tensorInfo(tensorShape, dataType);
26*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> inputData =
27*89c4ff92SAndroid Build Coastguard Worker     {
28*89c4ff92SAndroid Build Coastguard Worker          1,  2,  3,  4,  5,
29*89c4ff92SAndroid Build Coastguard Worker          6,  7,  8,  9, 10,
30*89c4ff92SAndroid Build Coastguard Worker         11, 12, 13, 14, 15,
31*89c4ff92SAndroid Build Coastguard Worker         16, 17, 18, 19, 20,
32*89c4ff92SAndroid Build Coastguard Worker         21, 22, 23, 24, 25,
33*89c4ff92SAndroid Build Coastguard Worker         26, 27, 28, 29, 30,
34*89c4ff92SAndroid Build Coastguard Worker     };
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     LayerTestResult<T, 4> ret(tensorInfo);
37*89c4ff92SAndroid Build Coastguard Worker     ret.m_ExpectedData = inputData;
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> actualOutput(tensorInfo.GetNumElements());
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
42*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo);
43*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorHandle = dstWorkloadFactory.CreateTensorHandle(tensorInfo);
44*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     AllocateAndCopyDataToITensorHandle(inputTensorHandle.get(), inputData.data());
47*89c4ff92SAndroid Build Coastguard Worker     outputTensorHandle->Allocate();
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     armnn::MemCopyQueueDescriptor memCopyQueueDesc;
50*89c4ff92SAndroid Build Coastguard Worker     armnn::WorkloadInfo workloadInfo;
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     const unsigned int origin[4] = {};
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
55*89c4ff92SAndroid Build Coastguard Worker     auto workloadInput  = (withSubtensors && srcWorkloadFactory.SupportsSubTensors())
56*89c4ff92SAndroid Build Coastguard Worker                               ? srcWorkloadFactory.CreateSubTensorHandle(*inputTensorHandle, tensorShape, origin)
57*89c4ff92SAndroid Build Coastguard Worker                               : std::move(inputTensorHandle);
58*89c4ff92SAndroid Build Coastguard Worker     auto workloadOutput = (withSubtensors && dstWorkloadFactory.SupportsSubTensors())
59*89c4ff92SAndroid Build Coastguard Worker                               ? dstWorkloadFactory.CreateSubTensorHandle(*outputTensorHandle, tensorShape, origin)
60*89c4ff92SAndroid Build Coastguard Worker                               : std::move(outputTensorHandle);
61*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     AddInputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadInput.get());
64*89c4ff92SAndroid Build Coastguard Worker     AddOutputToWorkload(memCopyQueueDesc, workloadInfo, tensorInfo, workloadOutput.get());
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     dstWorkloadFactory.CreateWorkload(armnn::LayerType::MemCopy, memCopyQueueDesc, workloadInfo)->Execute();
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get());
69*89c4ff92SAndroid Build Coastguard Worker     ret.m_ActualData = actualOutput;
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     return ret;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker template <typename WorkloadFactoryType>
75*89c4ff92SAndroid Build Coastguard Worker struct MemCopyTestHelper
76*89c4ff92SAndroid Build Coastguard Worker {};
77*89c4ff92SAndroid Build Coastguard Worker template <>
78*89c4ff92SAndroid Build Coastguard Worker struct MemCopyTestHelper<armnn::MockWorkloadFactory>
79*89c4ff92SAndroid Build Coastguard Worker {
GetMemoryManager__anon845caf830111::MemCopyTestHelper80*89c4ff92SAndroid Build Coastguard Worker     static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
81*89c4ff92SAndroid Build Coastguard Worker     {
82*89c4ff92SAndroid Build Coastguard Worker         armnn::MockBackend backend;
83*89c4ff92SAndroid Build Coastguard Worker         return backend.CreateMemoryManager();
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     static armnn::MockWorkloadFactory
GetFactory__anon845caf830111::MemCopyTestHelper87*89c4ff92SAndroid Build Coastguard Worker         GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr)
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(memoryManager);
90*89c4ff92SAndroid Build Coastguard Worker         return armnn::MockWorkloadFactory();
91*89c4ff92SAndroid Build Coastguard Worker     }
92*89c4ff92SAndroid Build Coastguard Worker };
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker using MockMemCopyTestHelper = MemCopyTestHelper<armnn::MockWorkloadFactory>;
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker template <typename SrcWorkloadFactory,
97*89c4ff92SAndroid Build Coastguard Worker           typename DstWorkloadFactory,
98*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType dataType,
99*89c4ff92SAndroid Build Coastguard Worker           typename T = armnn::ResolveType<dataType>>
MemCopyTest(bool withSubtensors)100*89c4ff92SAndroid Build Coastguard Worker LayerTestResult<T, 4> MemCopyTest(bool withSubtensors)
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker     armnn::IBackendInternal::IMemoryManagerSharedPtr srcMemoryManager =
104*89c4ff92SAndroid Build Coastguard Worker         MemCopyTestHelper<SrcWorkloadFactory>::GetMemoryManager();
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     armnn::IBackendInternal::IMemoryManagerSharedPtr dstMemoryManager =
107*89c4ff92SAndroid Build Coastguard Worker         MemCopyTestHelper<DstWorkloadFactory>::GetMemoryManager();
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     SrcWorkloadFactory srcWorkloadFactory = MemCopyTestHelper<SrcWorkloadFactory>::GetFactory(srcMemoryManager);
110*89c4ff92SAndroid Build Coastguard Worker     DstWorkloadFactory dstWorkloadFactory = MemCopyTestHelper<DstWorkloadFactory>::GetFactory(dstMemoryManager);
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     return MemCopyTest<dataType>(srcWorkloadFactory, dstWorkloadFactory, withSubtensors);
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker }    // anonymous namespace
116