// xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeBatchNormalization.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_BatchNormalization")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct BatchNormalizationFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
BatchNormalizationFixtureBatchNormalizationFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit BatchNormalizationFixture(const std::string &inputShape,
16*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &outputShape,
17*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &meanShape,
18*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &varianceShape,
19*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &offsetShape,
20*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &scaleShape,
21*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &dataType,
22*89c4ff92SAndroid Build Coastguard Worker                                        const std::string &dataLayout)
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
25*89c4ff92SAndroid Build Coastguard Worker     {
26*89c4ff92SAndroid Build Coastguard Worker         inputIds: [0],
27*89c4ff92SAndroid Build Coastguard Worker         outputIds: [2],
28*89c4ff92SAndroid Build Coastguard Worker         layers: [
29*89c4ff92SAndroid Build Coastguard Worker            {
30*89c4ff92SAndroid Build Coastguard Worker             layer_type: "InputLayer",
31*89c4ff92SAndroid Build Coastguard Worker             layer: {
32*89c4ff92SAndroid Build Coastguard Worker                 base: {
33*89c4ff92SAndroid Build Coastguard Worker                     layerBindingId: 0,
34*89c4ff92SAndroid Build Coastguard Worker                     base: {
35*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
36*89c4ff92SAndroid Build Coastguard Worker                         layerName: "InputLayer",
37*89c4ff92SAndroid Build Coastguard Worker                         layerType: "Input",
38*89c4ff92SAndroid Build Coastguard Worker                         inputSlots: [{
39*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
40*89c4ff92SAndroid Build Coastguard Worker                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
41*89c4ff92SAndroid Build Coastguard Worker                             }],
42*89c4ff92SAndroid Build Coastguard Worker                         outputSlots: [{
43*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
44*89c4ff92SAndroid Build Coastguard Worker                             tensorInfo: {
45*89c4ff92SAndroid Build Coastguard Worker                                 dimensions: )" + inputShape + R"(,
46*89c4ff92SAndroid Build Coastguard Worker                                 dataType: ")" + dataType + R"(",
47*89c4ff92SAndroid Build Coastguard Worker                                 quantizationScale: 0.5,
48*89c4ff92SAndroid Build Coastguard Worker                                 quantizationOffset: 0
49*89c4ff92SAndroid Build Coastguard Worker                                 },
50*89c4ff92SAndroid Build Coastguard Worker                             }]
51*89c4ff92SAndroid Build Coastguard Worker                         },
52*89c4ff92SAndroid Build Coastguard Worker                     }
53*89c4ff92SAndroid Build Coastguard Worker                 },
54*89c4ff92SAndroid Build Coastguard Worker             },
55*89c4ff92SAndroid Build Coastguard Worker         {
56*89c4ff92SAndroid Build Coastguard Worker         layer_type: "BatchNormalizationLayer",
57*89c4ff92SAndroid Build Coastguard Worker         layer : {
58*89c4ff92SAndroid Build Coastguard Worker             base: {
59*89c4ff92SAndroid Build Coastguard Worker                 index:1,
60*89c4ff92SAndroid Build Coastguard Worker                 layerName: "BatchNormalizationLayer",
61*89c4ff92SAndroid Build Coastguard Worker                 layerType: "BatchNormalization",
62*89c4ff92SAndroid Build Coastguard Worker                 inputSlots: [{
63*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
64*89c4ff92SAndroid Build Coastguard Worker                         connection: {sourceLayerIndex:0, outputSlotIndex:0 },
65*89c4ff92SAndroid Build Coastguard Worker                    }],
66*89c4ff92SAndroid Build Coastguard Worker                 outputSlots: [{
67*89c4ff92SAndroid Build Coastguard Worker                     index: 0,
68*89c4ff92SAndroid Build Coastguard Worker                     tensorInfo: {
69*89c4ff92SAndroid Build Coastguard Worker                         dimensions: )" + outputShape + R"(,
70*89c4ff92SAndroid Build Coastguard Worker                         dataType: ")" + dataType + R"("
71*89c4ff92SAndroid Build Coastguard Worker                     },
72*89c4ff92SAndroid Build Coastguard Worker                     }],
73*89c4ff92SAndroid Build Coastguard Worker                 },
74*89c4ff92SAndroid Build Coastguard Worker             descriptor: {
75*89c4ff92SAndroid Build Coastguard Worker                 eps: 0.0010000000475,
76*89c4ff92SAndroid Build Coastguard Worker                 dataLayout: ")" + dataLayout + R"("
77*89c4ff92SAndroid Build Coastguard Worker                 },
78*89c4ff92SAndroid Build Coastguard Worker             mean: {
79*89c4ff92SAndroid Build Coastguard Worker                 info: {
80*89c4ff92SAndroid Build Coastguard Worker                          dimensions: )" + meanShape + R"(,
81*89c4ff92SAndroid Build Coastguard Worker                          dataType: ")" + dataType + R"("
82*89c4ff92SAndroid Build Coastguard Worker                      },
83*89c4ff92SAndroid Build Coastguard Worker                 data_type: IntData,
84*89c4ff92SAndroid Build Coastguard Worker                 data: {
85*89c4ff92SAndroid Build Coastguard Worker                     data: [1084227584],
86*89c4ff92SAndroid Build Coastguard Worker                     }
87*89c4ff92SAndroid Build Coastguard Worker                 },
88*89c4ff92SAndroid Build Coastguard Worker             variance: {
89*89c4ff92SAndroid Build Coastguard Worker                 info: {
90*89c4ff92SAndroid Build Coastguard Worker                          dimensions: )" + varianceShape + R"(,
91*89c4ff92SAndroid Build Coastguard Worker                          dataType: ")" + dataType + R"("
92*89c4ff92SAndroid Build Coastguard Worker                      },
93*89c4ff92SAndroid Build Coastguard Worker                data_type: IntData,
94*89c4ff92SAndroid Build Coastguard Worker                 data: {
95*89c4ff92SAndroid Build Coastguard Worker                     data: [1073741824],
96*89c4ff92SAndroid Build Coastguard Worker                     }
97*89c4ff92SAndroid Build Coastguard Worker                 },
98*89c4ff92SAndroid Build Coastguard Worker             beta: {
99*89c4ff92SAndroid Build Coastguard Worker                 info: {
100*89c4ff92SAndroid Build Coastguard Worker                          dimensions: )" + offsetShape + R"(,
101*89c4ff92SAndroid Build Coastguard Worker                          dataType: ")" + dataType + R"("
102*89c4ff92SAndroid Build Coastguard Worker                      },
103*89c4ff92SAndroid Build Coastguard Worker                 data_type: IntData,
104*89c4ff92SAndroid Build Coastguard Worker                 data: {
105*89c4ff92SAndroid Build Coastguard Worker                     data: [0],
106*89c4ff92SAndroid Build Coastguard Worker                     }
107*89c4ff92SAndroid Build Coastguard Worker                 },
108*89c4ff92SAndroid Build Coastguard Worker             gamma: {
109*89c4ff92SAndroid Build Coastguard Worker                 info: {
110*89c4ff92SAndroid Build Coastguard Worker                          dimensions: )" + scaleShape + R"(,
111*89c4ff92SAndroid Build Coastguard Worker                          dataType: ")" + dataType + R"("
112*89c4ff92SAndroid Build Coastguard Worker                      },
113*89c4ff92SAndroid Build Coastguard Worker                 data_type: IntData,
114*89c4ff92SAndroid Build Coastguard Worker                 data: {
115*89c4ff92SAndroid Build Coastguard Worker                     data: [1065353216],
116*89c4ff92SAndroid Build Coastguard Worker                     }
117*89c4ff92SAndroid Build Coastguard Worker                 },
118*89c4ff92SAndroid Build Coastguard Worker             },
119*89c4ff92SAndroid Build Coastguard Worker         },
120*89c4ff92SAndroid Build Coastguard Worker         {
121*89c4ff92SAndroid Build Coastguard Worker         layer_type: "OutputLayer",
122*89c4ff92SAndroid Build Coastguard Worker         layer: {
123*89c4ff92SAndroid Build Coastguard Worker             base:{
124*89c4ff92SAndroid Build Coastguard Worker                 layerBindingId: 0,
125*89c4ff92SAndroid Build Coastguard Worker                 base: {
126*89c4ff92SAndroid Build Coastguard Worker                     index: 2,
127*89c4ff92SAndroid Build Coastguard Worker                     layerName: "OutputLayer",
128*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Output",
129*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [{
130*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
131*89c4ff92SAndroid Build Coastguard Worker                         connection: {sourceLayerIndex:1, outputSlotIndex:0 },
132*89c4ff92SAndroid Build Coastguard Worker                     }],
133*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [ {
134*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
135*89c4ff92SAndroid Build Coastguard Worker                         tensorInfo: {
136*89c4ff92SAndroid Build Coastguard Worker                             dimensions: )" + outputShape + R"(,
137*89c4ff92SAndroid Build Coastguard Worker                             dataType: ")" + dataType + R"("
138*89c4ff92SAndroid Build Coastguard Worker                         },
139*89c4ff92SAndroid Build Coastguard Worker                     }],
140*89c4ff92SAndroid Build Coastguard Worker                 }
141*89c4ff92SAndroid Build Coastguard Worker             }},
142*89c4ff92SAndroid Build Coastguard Worker         }]
143*89c4ff92SAndroid Build Coastguard Worker     }
144*89c4ff92SAndroid Build Coastguard Worker )";
145*89c4ff92SAndroid Build Coastguard Worker         Setup();
146*89c4ff92SAndroid Build Coastguard Worker     }
147*89c4ff92SAndroid Build Coastguard Worker };
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker struct BatchNormFixture : BatchNormalizationFixture
150*89c4ff92SAndroid Build Coastguard Worker {
BatchNormFixtureBatchNormFixture151*89c4ff92SAndroid Build Coastguard Worker     BatchNormFixture():BatchNormalizationFixture("[ 1, 3, 3, 1 ]",
152*89c4ff92SAndroid Build Coastguard Worker                                                  "[ 1, 3, 3, 1 ]",
153*89c4ff92SAndroid Build Coastguard Worker                                                  "[ 1 ]",
154*89c4ff92SAndroid Build Coastguard Worker                                                  "[ 1 ]",
155*89c4ff92SAndroid Build Coastguard Worker                                                  "[ 1 ]",
156*89c4ff92SAndroid Build Coastguard Worker                                                  "[ 1 ]",
157*89c4ff92SAndroid Build Coastguard Worker                                                  "Float32",
158*89c4ff92SAndroid Build Coastguard Worker                                                  "NHWC"){}
159*89c4ff92SAndroid Build Coastguard Worker };
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(BatchNormFixture, "BatchNormalizationFloat32")
162*89c4ff92SAndroid Build Coastguard Worker {
163*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0,
164*89c4ff92SAndroid Build Coastguard Worker                                          {{"InputLayer", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }}},
165*89c4ff92SAndroid Build Coastguard Worker                                          {{"OutputLayer",{ -2.8277204f, -2.12079024f, -1.4138602f,
166*89c4ff92SAndroid Build Coastguard Worker                                            -0.7069301f,  0.0f,         0.7069301f,
167*89c4ff92SAndroid Build Coastguard Worker                                            1.4138602f,  2.12079024f,  2.8277204f }}});
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker }
171