xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeGather.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_Gather")
12 {
13 struct GatherFixture : public ParserFlatbuffersSerializeFixture
14 {
GatherFixtureGatherFixture15     explicit GatherFixture(const std::string& inputShape,
16                            const std::string& indicesShape,
17                            const std::string& input1Content,
18                            const std::string& outputShape,
19                            const std::string& axis,
20                            const std::string dataType,
21                            const std::string constDataType)
22     {
23         m_JsonString = R"(
24         {
25                 inputIds: [0],
26                 outputIds: [3],
27                 layers: [
28                 {
29                     layer_type: "InputLayer",
30                     layer: {
31                           base: {
32                                 layerBindingId: 0,
33                                 base: {
34                                     index: 0,
35                                     layerName: "InputLayer",
36                                     layerType: "Input",
37                                     inputSlots: [{
38                                         index: 0,
39                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
40                                     }],
41                                     outputSlots: [ {
42                                         index: 0,
43                                         tensorInfo: {
44                                             dimensions: )" + inputShape + R"(,
45                                             dataType: )" + dataType + R"(
46                                             }}]
47                                     }
48                     }}},
49                     {
50                     layer_type: "ConstantLayer",
51                         layer: {
52                                base: {
53                                   index:1,
54                                   layerName: "ConstantLayer",
55                                   layerType: "Constant",
56                                    outputSlots: [ {
57                                     index: 0,
58                                     tensorInfo: {
59                                         dimensions: )" + indicesShape + R"(,
60                                         dataType: "Signed32",
61                                     },
62                                   }],
63                               },
64                               input: {
65                               info: {
66                                        dimensions: )" + indicesShape + R"(,
67                                        dataType: )" + dataType + R"(
68                                    },
69                               data_type: )" + constDataType + R"(,
70                               data: {
71                                   data: )" + input1Content + R"(,
72                                     } }
73                                 },},
74                     {
75                     layer_type: "GatherLayer",
76                         layer: {
77                               base: {
78                                    index: 2,
79                                    layerName: "GatherLayer",
80                                    layerType: "Gather",
81                                    inputSlots: [
82                                    {
83                                        index: 0,
84                                        connection: {sourceLayerIndex:0, outputSlotIndex:0 },
85                                    },
86                                    {
87                                         index: 1,
88                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 }
89                                    }],
90                                    outputSlots: [ {
91                                           index: 0,
92                                           tensorInfo: {
93                                                dimensions: )" + outputShape + R"(,
94                                                dataType: )" + dataType + R"(
95 
96                                    }}]},
97                                    descriptor: {
98                                        axis: )" + axis + R"(
99                                    }
100                         }},
101                     {
102                     layer_type: "OutputLayer",
103                     layer: {
104                         base:{
105                               layerBindingId: 0,
106                               base: {
107                                     index: 3,
108                                     layerName: "OutputLayer",
109                                     layerType: "Output",
110                                     inputSlots: [{
111                                         index: 0,
112                                         connection: {sourceLayerIndex:2, outputSlotIndex:0 },
113                                     }],
114                                     outputSlots: [ {
115                                         index: 0,
116                                         tensorInfo: {
117                                             dimensions: )" + outputShape + R"(,
118                                             dataType: )" + dataType + R"(
119                                         },
120                                 }],
121                             }}},
122                 }],
123                 featureVersions: {
124                     weightsLayoutScheme: 1,
125                 }
126                  } )";
127 
128         Setup();
129     }
130 };
131 
132 struct SimpleGatherFixtureFloat32 : GatherFixture
133 {
SimpleGatherFixtureFloat32SimpleGatherFixtureFloat32134     SimpleGatherFixtureFloat32() : GatherFixture("[ 3, 2, 3 ]", "[ 2, 3 ]", "[1, 2, 1, 2, 1, 0]",
135                                                  "[ 2, 3, 2, 3 ]", "0", "Float32", "IntData") {}
136 };
137 
138 TEST_CASE_FIXTURE(SimpleGatherFixtureFloat32, "GatherFloat32")
139 {
140     RunTest<4, armnn::DataType::Float32>(0,
141                                          {{"InputLayer", {  1,  2,  3,
142                                                             4,  5,  6,
143                                                             7,  8,  9,
144                                                             10, 11, 12,
145                                                             13, 14, 15,
146                                                             16, 17, 18 }}},
147                                          {{"OutputLayer", { 7,  8,  9,
148                                                             10, 11, 12,
149                                                             13, 14, 15,
150                                                             16, 17, 18,
151                                                             7,  8,  9,
152                                                             10, 11, 12,
153                                                             13, 14, 15,
154                                                             16, 17, 18,
155                                                             7,  8,  9,
156                                                             10, 11, 12,
157                                                             1,  2,  3,
158                                                             4,  5,  6 }}});
159 }
160 
161 }
162 
163