xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/GatherEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <CommonTestUtils.hpp>
9 
10 #include <armnn/INetwork.hpp>
11 #include <ResolveType.hpp>
12 
13 #include <doctest/doctest.h>
14 
15 namespace{
16 
CreateGatherNetwork(const armnn::TensorInfo & paramsInfo,const armnn::TensorInfo & indicesInfo,const armnn::TensorInfo & outputInfo,const std::vector<int32_t> & indicesData)17 armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
18                                        const armnn::TensorInfo& indicesInfo,
19                                        const armnn::TensorInfo& outputInfo,
20                                        const std::vector<int32_t>& indicesData)
21 {
22     armnn::INetworkPtr net(armnn::INetwork::Create());
23 
24     armnn::GatherDescriptor descriptor;
25     armnn::IConnectableLayer* paramsLayer = net->AddInputLayer(0);
26     armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
27     armnn::IConnectableLayer* gatherLayer = net->AddGatherLayer(descriptor, "gather");
28     armnn::IConnectableLayer* outputLayer = net->AddOutputLayer(0, "output");
29     Connect(paramsLayer, gatherLayer, paramsInfo, 0, 0);
30     Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
31     Connect(gatherLayer, outputLayer, outputInfo, 0, 0);
32 
33     return net;
34 }
35 
36 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherEndToEnd(const std::vector<BackendId> & backends)37 void GatherEndToEnd(const std::vector<BackendId>& backends)
38 {
39     armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
40     armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
41     armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
42 
43     paramsInfo.SetQuantizationScale(1.0f);
44     paramsInfo.SetQuantizationOffset(0);
45     paramsInfo.SetConstant(true);
46     indicesInfo.SetConstant(true);
47     outputInfo.SetQuantizationScale(1.0f);
48     outputInfo.SetQuantizationOffset(0);
49 
50     // Creates structures for input & output.
51     std::vector<T> paramsData{
52         1, 2, 3, 4, 5, 6, 7, 8
53     };
54 
55     std::vector<int32_t> indicesData{
56         7, 6, 5
57     };
58 
59     std::vector<T> expectedOutput{
60         8, 7, 6
61     };
62 
63     // Builds up the structure of the network
64     armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
65 
66     CHECK(net);
67 
68     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
69     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
70 
71     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
72 }
73 
74 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherMultiDimEndToEnd(const std::vector<BackendId> & backends)75 void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
76 {
77     armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
78     armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
79     armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
80 
81     paramsInfo.SetQuantizationScale(1.0f);
82     paramsInfo.SetQuantizationOffset(0);
83     paramsInfo.SetConstant(true);
84     indicesInfo.SetConstant(true);
85     outputInfo.SetQuantizationScale(1.0f);
86     outputInfo.SetQuantizationOffset(0);
87 
88     // Creates structures for input & output.
89     std::vector<T> paramsData{
90          1,  2,  3,
91          4,  5,  6,
92 
93          7,  8,  9,
94         10, 11, 12,
95 
96         13, 14, 15,
97         16, 17, 18
98     };
99 
100     std::vector<int32_t> indicesData{
101         1, 2, 1,
102         2, 1, 0
103     };
104 
105     std::vector<T> expectedOutput{
106          7,  8,  9,
107         10, 11, 12,
108         13, 14, 15,
109         16, 17, 18,
110          7,  8,  9,
111         10, 11, 12,
112 
113         13, 14, 15,
114         16, 17, 18,
115          7,  8,  9,
116         10, 11, 12,
117          1,  2,  3,
118          4,  5,  6
119     };
120 
121     // Builds up the structure of the network
122     armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
123 
124     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
125     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
126 
127     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
128 }
129 
130 } // anonymous namespace
131