1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "armnnOnnxParser/IOnnxParser.hpp" 7 #include "ParserPrototxtFixture.hpp" 8 9 TEST_SUITE("OnnxParser_BatchNorm") 10 { 11 struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12 { BatchNormalizationMainFixtureBatchNormalizationMainFixture13 BatchNormalizationMainFixture() 14 { 15 m_Prototext = R"( 16 ir_version: 3 17 producer_name: "CNTK" 18 producer_version: "2.5.1" 19 domain: "ai.cntk" 20 model_version: 1 21 graph { 22 name: "CNTKGraph" 23 input { 24 name: "Input" 25 type { 26 tensor_type { 27 elem_type: 1 28 shape { 29 dim { 30 dim_value: 1 31 } 32 dim { 33 dim_value: 1 34 } 35 dim { 36 dim_value: 3 37 } 38 dim { 39 dim_value: 3 40 } 41 } 42 } 43 } 44 } 45 input { 46 name: "mean" 47 type { 48 tensor_type { 49 elem_type: 1 50 shape { 51 dim { 52 dim_value: 1 53 } 54 } 55 } 56 } 57 } 58 input { 59 name: "var" 60 type { 61 tensor_type { 62 elem_type: 1 63 shape { 64 dim { 65 dim_value: 1 66 } 67 } 68 } 69 } 70 } 71 input { 72 name: "scale" 73 type { 74 tensor_type { 75 elem_type: 1 76 shape { 77 dim { 78 dim_value: 1 79 } 80 } 81 } 82 } 83 } 84 input { 85 name: "bias" 86 type { 87 tensor_type { 88 elem_type: 1 89 shape { 90 dim { 91 dim_value: 1 92 } 93 } 94 } 95 } 96 } 97 node { 98 input: "Input" 99 input: "scale" 100 input: "bias" 101 input: "mean" 102 input: "var" 103 output: "Output" 104 name: "batchNorm" 105 op_type: "BatchNormalization" 106 attribute { 107 name: "epsilon" 108 f: 0.0010000000475 109 type: 1 110 } 111 } 112 initializer { 113 dims: 1 114 data_type: 1 115 float_data: 5.0 116 name: "mean" 117 } 118 initializer { 119 dims: 1 120 data_type: 1 121 float_data: 2.0 122 name: "var" 123 } 124 initializer { 125 dims: 1 126 data_type: 1 127 float_data: 0.0 128 name: "bias" 129 } 130 initializer { 131 dims: 1 132 data_type: 1 133 float_data: 1.0 134 name: "scale" 135 } 136 output { 137 name: "Output" 138 type { 139 tensor_type { 140 elem_type: 1 141 shape { 142 dim { 143 dim_value: 1 144 } 145 dim { 146 dim_value: 1 147 } 148 dim { 149 dim_value: 3 150 } 151 dim { 152 dim_value: 3 153 } 154 } 155 } 156 } 157 } 158 } 159 opset_import { 160 version: 7 161 })"; 162 Setup(); 163 } 164 }; 165 166 TEST_CASE_FIXTURE(BatchNormalizationMainFixture, "ValidBatchNormalizationTest") 167 { 168 RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}}, // Input data. 169 {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f, 170 -0.7069301f, 0.0f, 0.7069301f, 171 1.4138602f, 2.12079024f, 2.8277204f}}}); // Expected output data. 172 } 173 174 175 struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 176 { BatchNormalizationBisFixtureBatchNormalizationBisFixture177 BatchNormalizationBisFixture() 178 { 179 m_Prototext = R"( 180 ir_version: 3 181 producer_name: "CNTK" 182 producer_version: "2.5.1" 183 domain: "ai.cntk" 184 model_version: 1 185 graph { 186 name: "CNTKGraph" 187 input { 188 name: "Input" 189 type { 190 tensor_type { 191 elem_type: 1 192 shape { 193 dim { 194 dim_value: 1 195 } 196 dim { 197 dim_value: 2 198 } 199 dim { 200 dim_value: 1 201 } 202 dim { 203 dim_value: 3 204 } 205 } 206 } 207 } 208 } 209 input { 210 name: "mean" 211 type { 212 tensor_type { 213 elem_type: 1 214 shape { 215 dim { 216 dim_value: 2 217 } 218 } 219 } 220 } 221 } 222 input { 223 name: "var" 224 type { 225 tensor_type { 226 elem_type: 1 227 shape { 228 dim { 229 dim_value: 2 230 } 231 } 232 } 233 } 234 } 235 input { 236 name: "scale" 237 type { 238 tensor_type { 239 elem_type: 1 240 shape { 241 dim { 242 dim_value: 2 243 } 244 } 245 } 246 } 247 } 248 input { 249 name: "bias" 250 type { 251 tensor_type { 252 elem_type: 1 253 shape { 254 dim { 255 dim_value: 2 256 } 257 } 258 } 259 } 260 } 261 node { 262 input: "Input" 263 input: "scale" 264 input: "bias" 265 input: "mean" 266 input: "var" 267 output: "Output" 268 name: "batchNorm" 269 op_type: "BatchNormalization" 270 attribute { 271 name: "epsilon" 272 f: 0.00001 273 type: 1 274 } 275 } 276 initializer { 277 dims: 2 278 data_type: 1 279 float_data: 0.0 280 float_data: 3.0 281 name: "mean" 282 } 283 initializer { 284 dims: 2 285 data_type: 1 286 float_data: 1.0 287 float_data: 1.5 288 name: "var" 289 } 290 initializer { 291 dims: 2 292 data_type: 1 293 float_data: 0.0 294 float_data: 1.0 295 name: "bias" 296 } 297 initializer { 298 dims: 2 299 data_type: 1 300 float_data: 1.0 301 float_data: 1.5 302 name: "scale" 303 } 304 output { 305 name: "Output" 306 type { 307 tensor_type { 308 elem_type: 1 309 shape { 310 dim { 311 dim_value: 1 312 } 313 dim { 314 dim_value: 2 315 } 316 dim { 317 dim_value: 1 318 } 319 dim { 320 dim_value: 3 321 } 322 } 323 } 324 } 325 } 326 } 327 opset_import { 328 version: 7 329 })"; 330 Setup(); 331 } 332 }; 333 334 TEST_CASE_FIXTURE(BatchNormalizationBisFixture, "ValidBatchNormalizationBisTest") 335 { 336 RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}}, // Input data. 337 {{"Output", {-0.999995f, 0.0, 0.999995f, 338 -0.22474074f, 1.0f, 2.2247407f}}}); // Expected output data. 339 } 340 341 } 342