xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/GatherNdEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <CommonTestUtils.hpp>
9 
10 #include <armnn/INetwork.hpp>
11 #include <ResolveType.hpp>
12 
13 #include <doctest/doctest.h>
14 
15 namespace{
16 
/// Builds a small network that computes GatherNd(params, indices).
///
/// Topology:
///   input layer 0 (params)              -> "gatherNd" input slot 0
///   constant layer built from indices   -> "gatherNd" input slot 1
///   "gatherNd"                          -> output layer 0 ("output")
///
/// @param paramsInfo  tensor info set on the params input layer's output slot.
/// @param indicesInfo tensor info of the indices; also used to build the
///                    ConstTensor, so callers flag it constant beforehand.
/// @param outputInfo  tensor info set on the GatherNd layer's output slot.
/// @param indicesData indices values wrapped into the constant layer
///                    (assumed to hold indicesInfo.GetNumElements() values).
/// @return the assembled network.
armnn::INetworkPtr CreateGatherNdNetwork(const armnn::TensorInfo& paramsInfo,
                                         const armnn::TensorInfo& indicesInfo,
                                         const armnn::TensorInfo& outputInfo,
                                         const std::vector<int32_t>& indicesData)
{
    armnn::INetworkPtr network(armnn::INetwork::Create());

    // Layers are created in a fixed order: params input, constant indices,
    // the GatherNd layer, then the output.
    armnn::IConnectableLayer* const params = network->AddInputLayer(0);

    const armnn::ConstTensor indicesTensor(indicesInfo, indicesData);
    armnn::IConnectableLayer* const indices = network->AddConstantLayer(indicesTensor);

    armnn::IConnectableLayer* const gatherNd = network->AddGatherNdLayer("gatherNd");
    armnn::IConnectableLayer* const output = network->AddOutputLayer(0, "output");

    // The two trailing Connect arguments are (source output slot, destination
    // input slot): params feeds GatherNd slot 0, indices feed slot 1.
    Connect(params, gatherNd, paramsInfo, 0, 0);
    Connect(indices, gatherNd, indicesInfo, 0, 1);
    Connect(gatherNd, output, outputInfo, 0, 0);

    return network;
}
34 
35 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherNdEndToEnd(const std::vector<BackendId> & backends)36 void GatherNdEndToEnd(const std::vector<BackendId>& backends)
37 {
38     armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
39     armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
40     armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
41 
42     paramsInfo.SetQuantizationScale(1.0f);
43     paramsInfo.SetQuantizationOffset(0);
44     paramsInfo.SetConstant(true);
45     indicesInfo.SetConstant(true);
46     outputInfo.SetQuantizationScale(1.0f);
47     outputInfo.SetQuantizationOffset(0);
48 
49     // Creates structures for input & output.
50     std::vector<T> paramsData{
51              0,   1,   2,   3, 4,   5,   6,   7, 8,   9,  10,  11, 12,  13,  14,  15,
52             16,  17,  18,  19, 20,  21,  22,  23, 24,  25,  26,  27, 28,  29,  30,  31,
53 
54             32,  33,  34,  35, 36,  37,  38,  39, 40,  41,  42,  43, 44,  45,  46,  47,
55             48,  49,  50,  51, 52,  53,  54,  55, 56,  57,  58,  59, 60,  61,  62,  63,
56 
57             64,  65,  66,  67, 68,  69,  70,  71, 72,  73,  74,  75, 76,  77,  78,  79,
58             80,  81,  82,  83, 84,  85,  86,  87, 88,  89,  90,  91, 92,  93,  94,  95,
59 
60             96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
61             112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
62 
63             128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
64             144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
65 
66             160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
67             176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191
68     };
69 
70     std::vector<int32_t> indicesData{
71             { 1, 2, 1, 1},
72     };
73 
74     std::vector<T> expectedOutput{
75         160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
76         176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
77 
78         128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
79         144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159
80     };
81 
82     // Builds up the structure of the network
83     armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
84 
85     CHECK(net);
86 
87     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
88     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
89 
90     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
91 }
92 
93 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherNdMultiDimEndToEnd(const std::vector<BackendId> & backends)94 void GatherNdMultiDimEndToEnd(const std::vector<BackendId>& backends)
95 {
96     armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
97     armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
98     armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
99 
100     paramsInfo.SetQuantizationScale(1.0f);
101     paramsInfo.SetQuantizationOffset(0);
102     paramsInfo.SetConstant(true);
103     indicesInfo.SetConstant(true);
104     outputInfo.SetQuantizationScale(1.0f);
105     outputInfo.SetQuantizationOffset(0);
106 
107     // Creates structures for input & output.
108     std::vector<T> paramsData{
109             0,  1,    2,  3,    4,  5,    6,  7,    8,  9,
110             10, 11,   12,  13,   14, 15,   16, 17,   18, 19,
111             20, 21,   22,  23,   24, 25,   26, 27,   28, 29,
112             30, 31,   32,  33,   34, 35,   36, 37,   38, 39,
113             40, 41,   42,  43,   44, 45,   46, 47,   48, 49
114     };
115 
116     std::vector<int32_t> indicesData{
117             0, 0,
118             3, 3,
119             4, 4,
120 
121             0, 0,
122             1, 1,
123             2, 2,
124 
125             4, 4,
126             3, 3,
127             0, 0,
128 
129             2, 2,
130             1, 1,
131             0, 0
132     };
133 
134     std::vector<T> expectedOutput{
135             0,  1,
136             36, 37,
137             48, 49,
138 
139             0,  1,
140             12, 13,
141             24, 25,
142 
143             48, 49,
144             36, 37,
145             0,  1,
146 
147             24, 25,
148             12, 13,
149             0,  1
150     };
151 
152     // Builds up the structure of the network
153     armnn::INetworkPtr net = CreateGatherNdNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
154 
155     std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
156     std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
157 
158     EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
159 }
160 
161 } // anonymous namespace
162