xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefMemCopyTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefWorkloadFactory.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <reference/RefBackend.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/LayerTestResult.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MemCopyTestImpl.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockBackend.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker namespace
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker template <>
17*89c4ff92SAndroid Build Coastguard Worker struct MemCopyTestHelper<armnn::RefWorkloadFactory>
18*89c4ff92SAndroid Build Coastguard Worker {
GetMemoryManager__anonf6c9f3bb0111::MemCopyTestHelper19*89c4ff92SAndroid Build Coastguard Worker     static armnn::IBackendInternal::IMemoryManagerSharedPtr GetMemoryManager()
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         armnn::RefBackend backend;
22*89c4ff92SAndroid Build Coastguard Worker         return backend.CreateMemoryManager();
23*89c4ff92SAndroid Build Coastguard Worker     }
24*89c4ff92SAndroid Build Coastguard Worker 
GetFactory__anonf6c9f3bb0111::MemCopyTestHelper25*89c4ff92SAndroid Build Coastguard Worker     static armnn::RefWorkloadFactory GetFactory(const armnn::IBackendInternal::IMemoryManagerSharedPtr&)
26*89c4ff92SAndroid Build Coastguard Worker     {
27*89c4ff92SAndroid Build Coastguard Worker         return armnn::RefWorkloadFactory();
28*89c4ff92SAndroid Build Coastguard Worker     }
29*89c4ff92SAndroid Build Coastguard Worker };
30*89c4ff92SAndroid Build Coastguard Worker }    // namespace
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("RefMemCopy")
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     TEST_CASE("CopyBetweenMockAccAndRef")
36*89c4ff92SAndroid Build Coastguard Worker     {
37*89c4ff92SAndroid Build Coastguard Worker         LayerTestResult<float, 4> result =
38*89c4ff92SAndroid Build Coastguard Worker             MemCopyTest<armnn::MockWorkloadFactory, armnn::RefWorkloadFactory, armnn::DataType::Float32>(false);
39*89c4ff92SAndroid Build Coastguard Worker         auto predResult =
40*89c4ff92SAndroid Build Coastguard Worker             CompareTensors(result.m_ActualData, result.m_ExpectedData, result.m_ActualShape, result.m_ExpectedShape);
41*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(predResult.m_Result, predResult.m_Message.str());
42*89c4ff92SAndroid Build Coastguard Worker     }
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     TEST_CASE("CopyBetweenRefAndMockAcc")
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         LayerTestResult<float, 4> result =
47*89c4ff92SAndroid Build Coastguard Worker             MemCopyTest<armnn::RefWorkloadFactory, armnn::MockWorkloadFactory, armnn::DataType::Float32>(false);
48*89c4ff92SAndroid Build Coastguard Worker         auto predResult =
49*89c4ff92SAndroid Build Coastguard Worker             CompareTensors(result.m_ActualData, result.m_ExpectedData, result.m_ActualShape, result.m_ExpectedShape);
50*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(predResult.m_Result, predResult.m_Message.str());
51*89c4ff92SAndroid Build Coastguard Worker     }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     TEST_CASE("CopyBetweenMockAccAndRefWithSubtensors")
54*89c4ff92SAndroid Build Coastguard Worker     {
55*89c4ff92SAndroid Build Coastguard Worker         LayerTestResult<float, 4> result =
56*89c4ff92SAndroid Build Coastguard Worker             MemCopyTest<armnn::MockWorkloadFactory, armnn::RefWorkloadFactory, armnn::DataType::Float32>(true);
57*89c4ff92SAndroid Build Coastguard Worker         auto predResult =
58*89c4ff92SAndroid Build Coastguard Worker             CompareTensors(result.m_ActualData, result.m_ExpectedData, result.m_ActualShape, result.m_ExpectedShape);
59*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(predResult.m_Result, predResult.m_Message.str());
60*89c4ff92SAndroid Build Coastguard Worker     }
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     TEST_CASE("CopyBetweenRefAndMockAccWithSubtensors")
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         LayerTestResult<float, 4> result =
65*89c4ff92SAndroid Build Coastguard Worker             MemCopyTest<armnn::RefWorkloadFactory, armnn::MockWorkloadFactory, armnn::DataType::Float32>(true);
66*89c4ff92SAndroid Build Coastguard Worker         auto predResult =
67*89c4ff92SAndroid Build Coastguard Worker             CompareTensors(result.m_ActualData, result.m_ExpectedData, result.m_ActualShape, result.m_ExpectedShape);
68*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(predResult.m_Result, predResult.m_Message.str());
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker }
71