1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include <CommonTestUtils.hpp>

#include <armnn/INetwork.hpp>
#include <ResolveType.hpp>

#include <doctest/doctest.h>

#include <cstdint>
#include <map>
#include <utility>
#include <vector>
14
15 namespace{
16
/// Builds a small network that gathers from a runtime input tensor using constant indices.
///
/// Topology: InputLayer(0) -> "gather" <- ConstantLayer(indices); "gather" -> OutputLayer(0, "output").
///
/// @param paramsInfo  Tensor info set on the params input layer's output connection.
/// @param indicesInfo Tensor info for the constant indices tensor. Callers in this file mark it
///                    constant (SetConstant(true)) before calling; presumably required because it
///                    is wrapped in an armnn::ConstTensor below — confirm.
/// @param outputInfo  Tensor info set on the gather layer's output connection.
/// @param indicesData Index values backing the constant layer. Assumed to contain
///                    indicesInfo.GetNumElements() values; this is not checked here.
/// @param descriptor  Gather parameters forwarded to AddGatherLayer. Defaults to a
///                    default-constructed GatherDescriptor (axis 0 per Arm NN's default — confirm).
/// @return The constructed network.
armnn::INetworkPtr CreateGatherNetwork(const armnn::TensorInfo& paramsInfo,
                                       const armnn::TensorInfo& indicesInfo,
                                       const armnn::TensorInfo& outputInfo,
                                       const std::vector<int32_t>& indicesData,
                                       const armnn::GatherDescriptor& descriptor = armnn::GatherDescriptor())
{
    armnn::INetworkPtr net(armnn::INetwork::Create());

    armnn::IConnectableLayer* paramsLayer  = net->AddInputLayer(0);
    armnn::IConnectableLayer* indicesLayer = net->AddConstantLayer(armnn::ConstTensor(indicesInfo, indicesData));
    armnn::IConnectableLayer* gatherLayer  = net->AddGatherLayer(descriptor, "gather");
    armnn::IConnectableLayer* outputLayer  = net->AddOutputLayer(0, "output");

    // Params feed gather input slot 0 and the constant indices feed gather input slot 1
    // (assuming Connect's trailing arguments are source output slot, destination input slot).
    Connect(paramsLayer,  gatherLayer, paramsInfo,  0, 0);
    Connect(indicesLayer, gatherLayer, indicesInfo, 0, 1);
    Connect(gatherLayer,  outputLayer, outputInfo,  0, 0);

    return net;
}
35
36 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherEndToEnd(const std::vector<BackendId> & backends)37 void GatherEndToEnd(const std::vector<BackendId>& backends)
38 {
39 armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
40 armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
41 armnn::TensorInfo outputInfo({ 3 }, ArmnnType);
42
43 paramsInfo.SetQuantizationScale(1.0f);
44 paramsInfo.SetQuantizationOffset(0);
45 paramsInfo.SetConstant(true);
46 indicesInfo.SetConstant(true);
47 outputInfo.SetQuantizationScale(1.0f);
48 outputInfo.SetQuantizationOffset(0);
49
50 // Creates structures for input & output.
51 std::vector<T> paramsData{
52 1, 2, 3, 4, 5, 6, 7, 8
53 };
54
55 std::vector<int32_t> indicesData{
56 7, 6, 5
57 };
58
59 std::vector<T> expectedOutput{
60 8, 7, 6
61 };
62
63 // Builds up the structure of the network
64 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
65
66 CHECK(net);
67
68 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
69 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
70
71 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
72 }
73
74 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
GatherMultiDimEndToEnd(const std::vector<BackendId> & backends)75 void GatherMultiDimEndToEnd(const std::vector<BackendId>& backends)
76 {
77 armnn::TensorInfo paramsInfo({ 3, 2, 3}, ArmnnType);
78 armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
79 armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
80
81 paramsInfo.SetQuantizationScale(1.0f);
82 paramsInfo.SetQuantizationOffset(0);
83 paramsInfo.SetConstant(true);
84 indicesInfo.SetConstant(true);
85 outputInfo.SetQuantizationScale(1.0f);
86 outputInfo.SetQuantizationOffset(0);
87
88 // Creates structures for input & output.
89 std::vector<T> paramsData{
90 1, 2, 3,
91 4, 5, 6,
92
93 7, 8, 9,
94 10, 11, 12,
95
96 13, 14, 15,
97 16, 17, 18
98 };
99
100 std::vector<int32_t> indicesData{
101 1, 2, 1,
102 2, 1, 0
103 };
104
105 std::vector<T> expectedOutput{
106 7, 8, 9,
107 10, 11, 12,
108 13, 14, 15,
109 16, 17, 18,
110 7, 8, 9,
111 10, 11, 12,
112
113 13, 14, 15,
114 16, 17, 18,
115 7, 8, 9,
116 10, 11, 12,
117 1, 2, 3,
118 4, 5, 6
119 };
120
121 // Builds up the structure of the network
122 armnn::INetworkPtr net = CreateGatherNetwork(paramsInfo, indicesInfo, outputInfo, indicesData);
123
124 std::map<int, std::vector<T>> inputTensorData = {{ 0, paramsData }};
125 std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
126
127 EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
128 }
129
130 } // anonymous namespace
131