xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/BatchNorm.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_BatchNorm")
10 {
11 struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
BatchNormalizationMainFixtureBatchNormalizationMainFixture13     BatchNormalizationMainFixture()
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input"
25                         type {
26                           tensor_type {
27                             elem_type: 1
28                             shape {
29                               dim {
30                                 dim_value: 1
31                               }
32                               dim {
33                                 dim_value: 1
34                               }
35                               dim {
36                                 dim_value: 3
37                               }
38                               dim {
39                                 dim_value: 3
40                               }
41                             }
42                           }
43                         }
44                       }
45                       input {
46                          name: "mean"
47                          type {
48                            tensor_type {
49                              elem_type: 1
50                              shape {
51                                dim {
52                                  dim_value: 1
53                                }
54                              }
55                            }
56                          }
57                        }
58                        input {
59                           name: "var"
60                           type {
61                             tensor_type {
62                               elem_type: 1
63                               shape {
64                                 dim {
65                                   dim_value: 1
66                                 }
67                               }
68                             }
69                           }
70                         }
71                         input {
72                            name: "scale"
73                            type {
74                              tensor_type {
75                                elem_type: 1
76                                shape {
77                                  dim {
78                                    dim_value: 1
79                                  }
80                                }
81                              }
82                            }
83                          }
84                          input {
85                             name: "bias"
86                             type {
87                               tensor_type {
88                                 elem_type: 1
89                                 shape {
90                                   dim {
91                                     dim_value: 1
92                                   }
93                                 }
94                               }
95                             }
96                           }
97                      node {
98                          input: "Input"
99                          input: "scale"
100                          input: "bias"
101                          input: "mean"
102                          input: "var"
103                          output: "Output"
104                          name: "batchNorm"
105                          op_type: "BatchNormalization"
106                          attribute {
107                            name: "epsilon"
108                            f:  0.0010000000475
109                            type: 1
110                          }
111                       }
112                       initializer {
113                           dims: 1
114                           data_type: 1
115                           float_data: 5.0
116                           name: "mean"
117                         }
118                       initializer {
119                         dims: 1
120                         data_type: 1
121                         float_data: 2.0
122                         name: "var"
123                       }
124                       initializer {
125                         dims: 1
126                         data_type: 1
127                         float_data: 0.0
128                         name: "bias"
129                       }
130                       initializer {
131                         dims: 1
132                         data_type: 1
133                         float_data: 1.0
134                         name: "scale"
135                       }
136                       output {
137                           name: "Output"
138                           type {
139                              tensor_type {
140                                elem_type: 1
141                                shape {
142                                    dim {
143                                        dim_value: 1
144                                    }
145                                    dim {
146                                        dim_value: 1
147                                    }
148                                    dim {
149                                        dim_value: 3
150                                    }
151                                    dim {
152                                        dim_value: 3
153                                    }
154                                }
155                             }
156                         }
157                         }
158                     }
159                    opset_import {
160                       version: 7
161                     })";
162         Setup();
163     }
164 };
165 
166 TEST_CASE_FIXTURE(BatchNormalizationMainFixture, "ValidBatchNormalizationTest")
167 {
168     RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}},             // Input data.
169                {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f,
170                 -0.7069301f, 0.0f, 0.7069301f,
171                 1.4138602f, 2.12079024f, 2.8277204f}}});  // Expected output data.
172 }
173 
174 
175 struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
176 {
BatchNormalizationBisFixtureBatchNormalizationBisFixture177     BatchNormalizationBisFixture()
178     {
179         m_Prototext = R"(
180                    ir_version: 3
181                    producer_name:  "CNTK"
182                    producer_version:  "2.5.1"
183                    domain:  "ai.cntk"
184                    model_version: 1
185                    graph {
186                      name:  "CNTKGraph"
187                      input {
188                         name: "Input"
189                         type {
190                           tensor_type {
191                             elem_type: 1
192                             shape {
193                               dim {
194                                 dim_value: 1
195                               }
196                               dim {
197                                 dim_value: 2
198                               }
199                               dim {
200                                 dim_value: 1
201                               }
202                               dim {
203                                 dim_value: 3
204                               }
205                             }
206                           }
207                         }
208                       }
209                       input {
210                          name: "mean"
211                          type {
212                            tensor_type {
213                              elem_type: 1
214                              shape {
215                                dim {
216                                  dim_value: 2
217                                }
218                              }
219                            }
220                          }
221                        }
222                        input {
223                           name: "var"
224                           type {
225                             tensor_type {
226                               elem_type: 1
227                               shape {
228                                 dim {
229                                   dim_value: 2
230                                 }
231                               }
232                             }
233                           }
234                         }
235                         input {
236                            name: "scale"
237                            type {
238                              tensor_type {
239                                elem_type: 1
240                                shape {
241                                  dim {
242                                    dim_value: 2
243                                  }
244                                }
245                              }
246                            }
247                          }
248                          input {
249                             name: "bias"
250                             type {
251                               tensor_type {
252                                 elem_type: 1
253                                 shape {
254                                   dim {
255                                     dim_value: 2
256                                   }
257                                 }
258                               }
259                             }
260                           }
261                      node {
262                          input: "Input"
263                          input: "scale"
264                          input: "bias"
265                          input: "mean"
266                          input: "var"
267                          output: "Output"
268                          name: "batchNorm"
269                          op_type: "BatchNormalization"
270                          attribute {
271                            name: "epsilon"
272                            f:  0.00001
273                            type: 1
274                          }
275                       }
276                       initializer {
277                           dims: 2
278                           data_type: 1
279                           float_data: 0.0
280                           float_data: 3.0
281                           name: "mean"
282                         }
283                       initializer {
284                         dims: 2
285                         data_type: 1
286                         float_data: 1.0
287                         float_data: 1.5
288                         name: "var"
289                       }
290                       initializer {
291                         dims: 2
292                         data_type: 1
293                         float_data: 0.0
294                         float_data: 1.0
295                         name: "bias"
296                       }
297                       initializer {
298                         dims: 2
299                         data_type: 1
300                         float_data: 1.0
301                         float_data: 1.5
302                         name: "scale"
303                       }
304                       output {
305                           name: "Output"
306                           type {
307                              tensor_type {
308                                elem_type: 1
309                                shape {
310                                    dim {
311                                        dim_value: 1
312                                    }
313                                    dim {
314                                        dim_value: 2
315                                    }
316                                    dim {
317                                        dim_value: 1
318                                    }
319                                    dim {
320                                        dim_value: 3
321                                    }
322                                }
323                             }
324                         }
325                         }
326                     }
327                    opset_import {
328                       version: 7
329                     })";
330         Setup();
331     }
332 };
333 
334 TEST_CASE_FIXTURE(BatchNormalizationBisFixture, "ValidBatchNormalizationBisTest")
335 {
336     RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}},           // Input data.
337                {{"Output", {-0.999995f, 0.0, 0.999995f,
338                             -0.22474074f, 1.0f, 2.2247407f}}});  // Expected output data.
339 }
340 
341 }
342