//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersSerializeFixture.hpp"
#include <armnnDeserializer/IDeserializer.hpp>

#include <string>

11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Normalization")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct NormalizationFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
NormalizationFixtureNormalizationFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit NormalizationFixture(const std::string &inputShape,
16*89c4ff92SAndroid Build Coastguard Worker         const std::string & outputShape,
17*89c4ff92SAndroid Build Coastguard Worker         const std::string &dataType,
18*89c4ff92SAndroid Build Coastguard Worker         const std::string &normAlgorithmChannel,
19*89c4ff92SAndroid Build Coastguard Worker         const std::string &normAlgorithmMethod,
20*89c4ff92SAndroid Build Coastguard Worker         const std::string &dataLayout)
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
23*89c4ff92SAndroid Build Coastguard Worker         {
24*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0],
25*89c4ff92SAndroid Build Coastguard Worker             outputIds: [2],
26*89c4ff92SAndroid Build Coastguard Worker             layers: [{
27*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
28*89c4ff92SAndroid Build Coastguard Worker                 layer: {
29*89c4ff92SAndroid Build Coastguard Worker                     base: {
30*89c4ff92SAndroid Build Coastguard Worker                         layerBindingId: 0,
31*89c4ff92SAndroid Build Coastguard Worker                         base: {
32*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
33*89c4ff92SAndroid Build Coastguard Worker                             layerName: "InputLayer",
34*89c4ff92SAndroid Build Coastguard Worker                             layerType: "Input",
35*89c4ff92SAndroid Build Coastguard Worker                             inputSlots: [{
36*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
37*89c4ff92SAndroid Build Coastguard Worker                                 connection: {sourceLayerIndex:0, outputSlotIndex:0 },
38*89c4ff92SAndroid Build Coastguard Worker                                 }],
39*89c4ff92SAndroid Build Coastguard Worker                             outputSlots: [{
40*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
41*89c4ff92SAndroid Build Coastguard Worker                                 tensorInfo: {
42*89c4ff92SAndroid Build Coastguard Worker                                     dimensions: )" + inputShape + R"(,
43*89c4ff92SAndroid Build Coastguard Worker                                     dataType: )" + dataType + R"(,
44*89c4ff92SAndroid Build Coastguard Worker                                     quantizationScale: 0.5,
45*89c4ff92SAndroid Build Coastguard Worker                                     quantizationOffset: 0
46*89c4ff92SAndroid Build Coastguard Worker                                     },
47*89c4ff92SAndroid Build Coastguard Worker                                 }]
48*89c4ff92SAndroid Build Coastguard Worker                             },
49*89c4ff92SAndroid Build Coastguard Worker                         }
50*89c4ff92SAndroid Build Coastguard Worker                     },
51*89c4ff92SAndroid Build Coastguard Worker                 },
52*89c4ff92SAndroid Build Coastguard Worker             {
53*89c4ff92SAndroid Build Coastguard Worker             layer_type: "NormalizationLayer",
54*89c4ff92SAndroid Build Coastguard Worker             layer : {
55*89c4ff92SAndroid Build Coastguard Worker                 base: {
56*89c4ff92SAndroid Build Coastguard Worker                     index:1,
57*89c4ff92SAndroid Build Coastguard Worker                     layerName: "NormalizationLayer",
58*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Normalization",
59*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [{
60*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
61*89c4ff92SAndroid Build Coastguard Worker                             connection: {sourceLayerIndex:0, outputSlotIndex:0 },
62*89c4ff92SAndroid Build Coastguard Worker                         }],
63*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [{
64*89c4ff92SAndroid Build Coastguard Worker                         index: 0,
65*89c4ff92SAndroid Build Coastguard Worker                         tensorInfo: {
66*89c4ff92SAndroid Build Coastguard Worker                             dimensions: )" + outputShape + R"(,
67*89c4ff92SAndroid Build Coastguard Worker                             dataType: )" + dataType + R"(
68*89c4ff92SAndroid Build Coastguard Worker                         },
69*89c4ff92SAndroid Build Coastguard Worker                         }],
70*89c4ff92SAndroid Build Coastguard Worker                     },
71*89c4ff92SAndroid Build Coastguard Worker                 descriptor: {
72*89c4ff92SAndroid Build Coastguard Worker                     normChannelType: )" + normAlgorithmChannel + R"(,
73*89c4ff92SAndroid Build Coastguard Worker                     normMethodType: )" + normAlgorithmMethod + R"(,
74*89c4ff92SAndroid Build Coastguard Worker                     normSize: 3,
75*89c4ff92SAndroid Build Coastguard Worker                     alpha: 1,
76*89c4ff92SAndroid Build Coastguard Worker                     beta: 1,
77*89c4ff92SAndroid Build Coastguard Worker                     k: 1,
78*89c4ff92SAndroid Build Coastguard Worker                     dataLayout: )" + dataLayout + R"(
79*89c4ff92SAndroid Build Coastguard Worker                     }
80*89c4ff92SAndroid Build Coastguard Worker                 },
81*89c4ff92SAndroid Build Coastguard Worker             },
82*89c4ff92SAndroid Build Coastguard Worker             {
83*89c4ff92SAndroid Build Coastguard Worker             layer_type: "OutputLayer",
84*89c4ff92SAndroid Build Coastguard Worker             layer: {
85*89c4ff92SAndroid Build Coastguard Worker                 base:{
86*89c4ff92SAndroid Build Coastguard Worker                     layerBindingId: 0,
87*89c4ff92SAndroid Build Coastguard Worker                     base: {
88*89c4ff92SAndroid Build Coastguard Worker                         index: 2,
89*89c4ff92SAndroid Build Coastguard Worker                         layerName: "OutputLayer",
90*89c4ff92SAndroid Build Coastguard Worker                         layerType: "Output",
91*89c4ff92SAndroid Build Coastguard Worker                         inputSlots: [{
92*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
93*89c4ff92SAndroid Build Coastguard Worker                             connection: {sourceLayerIndex:1, outputSlotIndex:0 },
94*89c4ff92SAndroid Build Coastguard Worker                         }],
95*89c4ff92SAndroid Build Coastguard Worker                         outputSlots: [ {
96*89c4ff92SAndroid Build Coastguard Worker                             index: 0,
97*89c4ff92SAndroid Build Coastguard Worker                             tensorInfo: {
98*89c4ff92SAndroid Build Coastguard Worker                                 dimensions: )" + outputShape + R"(,
99*89c4ff92SAndroid Build Coastguard Worker                                 dataType: )" + dataType + R"(
100*89c4ff92SAndroid Build Coastguard Worker                             },
101*89c4ff92SAndroid Build Coastguard Worker                         }],
102*89c4ff92SAndroid Build Coastguard Worker                     }
103*89c4ff92SAndroid Build Coastguard Worker                 }},
104*89c4ff92SAndroid Build Coastguard Worker             }]
105*89c4ff92SAndroid Build Coastguard Worker         }
106*89c4ff92SAndroid Build Coastguard Worker  )";
107*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
108*89c4ff92SAndroid Build Coastguard Worker     }
109*89c4ff92SAndroid Build Coastguard Worker };
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker struct FloatNhwcLocalBrightnessAcrossNormalizationFixture : NormalizationFixture
112*89c4ff92SAndroid Build Coastguard Worker {
FloatNhwcLocalBrightnessAcrossNormalizationFixtureFloatNhwcLocalBrightnessAcrossNormalizationFixture113*89c4ff92SAndroid Build Coastguard Worker     FloatNhwcLocalBrightnessAcrossNormalizationFixture() : NormalizationFixture("[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]",
114*89c4ff92SAndroid Build Coastguard Worker         "Float32", "0", "0", "NHWC") {}
115*89c4ff92SAndroid Build Coastguard Worker };
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FloatNhwcLocalBrightnessAcrossNormalizationFixture, "Float32NormalizationNhwcDataLayout")
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 1.0f, 2.0f, 3.0f, 4.0f,
121*89c4ff92SAndroid Build Coastguard Worker                                               5.0f, 6.0f, 7.0f, 8.0f },
122*89c4ff92SAndroid Build Coastguard Worker                                             { 0.5f, 0.400000006f, 0.300000012f, 0.235294119f,
123*89c4ff92SAndroid Build Coastguard Worker                                               0.192307696f, 0.16216217f, 0.140000001f, 0.123076923f });
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker struct FloatNchwLocalBrightnessWithinNormalizationFixture : NormalizationFixture
127*89c4ff92SAndroid Build Coastguard Worker {
FloatNchwLocalBrightnessWithinNormalizationFixtureFloatNchwLocalBrightnessWithinNormalizationFixture128*89c4ff92SAndroid Build Coastguard Worker     FloatNchwLocalBrightnessWithinNormalizationFixture() : NormalizationFixture("[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]",
129*89c4ff92SAndroid Build Coastguard Worker         "Float32", "1", "0", "NCHW") {}
130*89c4ff92SAndroid Build Coastguard Worker };
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FloatNchwLocalBrightnessWithinNormalizationFixture, "Float32NormalizationNchwDataLayout")
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 1.0f, 2.0f, 3.0f, 4.0f,
135*89c4ff92SAndroid Build Coastguard Worker                                               5.0f, 6.0f, 7.0f, 8.0f },
136*89c4ff92SAndroid Build Coastguard Worker                                             { 0.0322581f, 0.0645161f, 0.0967742f, 0.1290323f,
137*89c4ff92SAndroid Build Coastguard Worker                                               0.0285714f, 0.0342857f, 0.04f, 0.0457143f });
138*89c4ff92SAndroid Build Coastguard Worker }
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker }