// Listing of external/armnn/src/armnnDeserializer/test/DeserializeMean.cpp from the AOSP aosp_15_r20 tree (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810).
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include <armnnDeserializer/IDeserializer.hpp>
8 
9 #include <string>
10 
11 TEST_SUITE("Deserializer_Mean")
12 {
13 struct MeanFixture : public ParserFlatbuffersSerializeFixture
14 {
MeanFixtureMeanFixture15     explicit MeanFixture(const std::string &inputShape,
16                          const std::string &outputShape,
17                          const std::string &axis,
18                          const std::string &dataType)
19     {
20         m_JsonString = R"(
21             {
22                 inputIds: [0],
23                 outputIds: [2],
24                 layers: [
25                     {
26                         layer_type: "InputLayer",
27                         layer: {
28                             base: {
29                                 layerBindingId: 0,
30                                 base: {
31                                     index: 0,
32                                     layerName: "InputLayer",
33                                     layerType: "Input",
34                                     inputSlots: [{
35                                         index: 0,
36                                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
37                                     }],
38                                     outputSlots: [{
39                                         index: 0,
40                                         tensorInfo: {
41                                             dimensions: )" + inputShape + R"(,
42                                             dataType: )" + dataType + R"(
43                                         }
44                                     }]
45                                 }
46                             }
47                         }
48                     },
49                     {
50                         layer_type: "MeanLayer",
51                         layer: {
52                             base: {
53                                 index: 1,
54                                 layerName: "MeanLayer",
55                                 layerType: "Mean",
56                                 inputSlots: [{
57                                     index: 0,
58                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
59                                 }],
60                                 outputSlots: [{
61                                     index: 0,
62                                     tensorInfo: {
63                                         dimensions: )" + outputShape + R"(,
64                                         dataType: )" + dataType + R"(
65                                     }
66                                 }]
67                             },
68                             descriptor: {
69                                 axis: )" + axis + R"(,
70                                 keepDims: true
71                             }
72                         }
73                     },
74                     {
75                         layer_type: "OutputLayer",
76                         layer: {
77                             base:{
78                                 layerBindingId: 2,
79                                 base: {
80                                     index: 2,
81                                     layerName: "OutputLayer",
82                                     layerType: "Output",
83                                     inputSlots: [{
84                                         index: 0,
85                                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
86                                     }],
87                                     outputSlots: [{
88                                         index: 0,
89                                         tensorInfo: {
90                                             dimensions: )" + outputShape + R"(,
91                                             dataType: )" + dataType + R"(
92                                         },
93                                     }],
94                                 }
95                             }
96                         },
97                     }
98                 ]
99             }
100         )";
101         Setup();
102     }
103 };
104 
105 struct SimpleMeanFixture : MeanFixture
106 {
SimpleMeanFixtureSimpleMeanFixture107     SimpleMeanFixture()
108         : MeanFixture("[ 1, 1, 3, 2 ]",     // inputShape
109                       "[ 1, 1, 1, 2 ]",     // outputShape
110                       "[ 2 ]",              // axis
111                       "Float32")            // dataType
112     {}
113 };
114 
115 TEST_CASE_FIXTURE(SimpleMeanFixture, "SimpleMean")
116 {
117     RunTest<4, armnn::DataType::Float32>(
118          0,
119          {{"InputLayer",  { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}},
120          {{"OutputLayer", { 2.0f, 2.0f }}});
121 }
122 
123 }