# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
import inspect

import pytest

import pyarmnn as ann
import numpy as np
import pyarmnn._generated.pyarmnn as generated


def test_activation_descriptor_default_values():
    """Check the field values of a default-constructed ActivationDescriptor."""
    desc = ann.ActivationDescriptor()
    assert desc.m_Function == ann.ActivationFunction_Sigmoid
    assert desc.m_A == 0
    assert desc.m_B == 0


def test_argminmax_descriptor_default_values():
    """Check the field values of a default-constructed ArgMinMaxDescriptor."""
    desc = ann.ArgMinMaxDescriptor()
    assert desc.m_Function == ann.ArgMinMaxFunction_Min
    assert desc.m_Axis == -1


def test_batchnormalization_descriptor_default_values():
    """Check the field values of a default-constructed BatchNormalizationDescriptor."""
    desc = ann.BatchNormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without the assert the check is a no-op.
    assert np.allclose(0.0001, desc.m_Eps)


def test_batchtospacend_descriptor_default_values():
    """Check the field values of a default-constructed BatchToSpaceNdDescriptor."""
    desc = ann.BatchToSpaceNdDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    assert [1, 1] == desc.m_BlockShape
    assert [(0, 0), (0, 0)] == desc.m_Crops


def test_batchtospacend_descriptor_assignment():
    """Check that values assigned to BatchToSpaceNdDescriptor fields read back unchanged."""
    desc = ann.BatchToSpaceNdDescriptor()
    desc.m_BlockShape = (1, 2, 3)

    crops = [(1, 2), (3, 4)]
    crops_len_before = len(crops)
    desc.m_Crops = crops

    # The assignment must not change the length of the caller's list.
    assert crops_len_before == len(crops)
    desc.m_DataLayout = ann.DataLayout_NHWC
    assert ann.DataLayout_NHWC == desc.m_DataLayout
    assert [1, 2, 3] == desc.m_BlockShape
    assert [(1, 2), (3, 4)] == desc.m_Crops


@pytest.mark.parametrize("input_shape, value, vtype", [([-1], -1, 'int'),
                                                       (("one", "two"), "'one'", 'str'),
                                                       ([1.33, 4.55], 1.33, 'float'),
                                                       ([{1: "one"}], "{1: 'one'}", 'dict')], ids=str)
def test_batchtospacend_descriptor_rubbish_assignment_shape(input_shape, value, vtype):
    """Check that assigning an unconvertible block shape raises TypeError.

    ``value`` and ``vtype`` are the offending element and its Python type name
    as they are expected to appear in the error message.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_BlockShape = input_shape

    assert "Failed to convert python input value {} of type '{}' to C type 'j'".format(value, vtype) in str(err.value)


@pytest.mark.parametrize("input_crops, value, vtype", [([(1, 2), (3, 4, 5)], '(3, 4, 5)', 'tuple'),
                                                       ([(1, 'one')], "(1, 'one')", 'tuple'),
                                                       ([-1], -1, 'int'),
                                                       ([(1, (1, 2))], '(1, (1, 2))', 'tuple'),
                                                       ([[1, [1, 2]]], '[1, [1, 2]]', 'list')
                                                       ], ids=str)
def test_batchtospacend_descriptor_rubbish_assignment_crops(input_crops, value, vtype):
    """Check that assigning unconvertible crops raises TypeError.

    ``value`` and ``vtype`` are the offending element and its Python type name
    as they are expected to appear in the error message.
    """
    desc = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as err:
        desc.m_Crops = input_crops

    assert "Failed to convert python input value {} of type '{}' to C type".format(value, vtype) in str(err.value)


def test_batchtospacend_descriptor_empty_assignment():
    """Check that an empty block shape can be assigned and read back."""
    desc = ann.BatchToSpaceNdDescriptor()
    desc.m_BlockShape = []
    assert [] == desc.m_BlockShape


def test_batchtospacend_descriptor_ctor():
    """Check that the BatchToSpaceNdDescriptor constructor stores block shape and crops."""
    desc = ann.BatchToSpaceNdDescriptor([1, 2, 3], [(4, 5), (6, 7)])
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    assert [1, 2, 3] == desc.m_BlockShape
    assert [(4, 5), (6, 7)] == desc.m_Crops


def test_channelshuffle_descriptor_default_values():
    """Check the field values of a default-constructed ChannelShuffleDescriptor."""
    desc = ann.ChannelShuffleDescriptor()
    assert desc.m_Axis == 0
    assert desc.m_NumGroups == 0


def test_convolution2d_descriptor_default_values():
    """Check the field values of a default-constructed Convolution2dDescriptor."""
    desc = ann.Convolution2dDescriptor()
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_StrideX == 1
    assert desc.m_StrideY == 1
    assert desc.m_DilationX == 1
    assert desc.m_DilationY == 1
    assert desc.m_BiasEnabled == False
    assert desc.m_DataLayout == ann.DataLayout_NCHW


def test_convolution3d_descriptor_default_values():
    """Check the field values of a default-constructed Convolution3dDescriptor."""
    desc = ann.Convolution3dDescriptor()
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_PadFront == 0
    assert desc.m_PadBack == 0
    assert desc.m_StrideX == 1
    assert desc.m_StrideY == 1
    assert desc.m_StrideZ == 1
    assert desc.m_DilationX == 1
    assert desc.m_DilationY == 1
    assert desc.m_DilationZ == 1
    assert desc.m_BiasEnabled == False
    assert desc.m_DataLayout == ann.DataLayout_NDHWC


def test_depthtospace_descriptor_default_values():
    """Check the field values of a default-constructed DepthToSpaceDescriptor."""
    desc = ann.DepthToSpaceDescriptor()
    assert desc.m_BlockSize == 1
    assert desc.m_DataLayout == ann.DataLayout_NHWC


def test_depthwise_convolution2d_descriptor_default_values():
    """Check the field values of a default-constructed DepthwiseConvolution2dDescriptor."""
    desc = ann.DepthwiseConvolution2dDescriptor()
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_StrideX == 1
    assert desc.m_StrideY == 1
    assert desc.m_DilationX == 1
    assert desc.m_DilationY == 1
    assert desc.m_BiasEnabled == False
    assert desc.m_DataLayout == ann.DataLayout_NCHW


def test_detectionpostprocess_descriptor_default_values():
    """Check the field values of a default-constructed DetectionPostProcessDescriptor."""
    desc = ann.DetectionPostProcessDescriptor()
    assert desc.m_MaxDetections == 0
    assert desc.m_MaxClassesPerDetection == 1
    assert desc.m_DetectionsPerClass == 1
    assert desc.m_NmsScoreThreshold == 0
    assert desc.m_NmsIouThreshold == 0
    assert desc.m_NumClasses == 0
    assert desc.m_UseRegularNms == False
    assert desc.m_ScaleH == 0
    assert desc.m_ScaleW == 0
    assert desc.m_ScaleX == 0
    assert desc.m_ScaleY == 0


def test_fakequantization_descriptor_default_values():
    """Check the field values of a default-constructed FakeQuantizationDescriptor."""
    desc = ann.FakeQuantizationDescriptor()
    # np.allclose only returns a bool; it has to be asserted to check anything.
    assert np.allclose(6, desc.m_Max)
    assert np.allclose(-6, desc.m_Min)


def test_fill_descriptor_default_values():
    """Check the field values of a default-constructed FillDescriptor."""
    desc = ann.FillDescriptor()
    assert np.allclose(0, desc.m_Value)


def test_gather_descriptor_default_values():
    """Check the field values of a default-constructed GatherDescriptor."""
    desc = ann.GatherDescriptor()
    assert desc.m_Axis == 0


def test_fully_connected_descriptor_default_values():
    """Check the field values of a default-constructed FullyConnectedDescriptor."""
    desc = ann.FullyConnectedDescriptor()
    assert desc.m_BiasEnabled == False
    assert desc.m_TransposeWeightMatrix == False


def test_instancenormalization_descriptor_default_values():
    """Check the field values of a default-constructed InstanceNormalizationDescriptor."""
    desc = ann.InstanceNormalizationDescriptor()
    assert desc.m_Gamma == 1
    assert desc.m_Beta == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # NOTE: np.allclose's default absolute tolerance (1e-8) is far larger than
    # 1e-12, so this only verifies that m_Eps is within 1e-8 of the expected value.
    assert np.allclose(1e-12, desc.m_Eps)


def test_lstm_descriptor_default_values():
    """Check the field values of a default-constructed LstmDescriptor."""
    desc = ann.LstmDescriptor()
    assert desc.m_ActivationFunc == 1
    assert desc.m_ClippingThresCell == 0
    assert desc.m_ClippingThresProj == 0
    assert desc.m_CifgEnabled == True
    assert desc.m_PeepholeEnabled == False
    assert desc.m_ProjectionEnabled == False
    assert desc.m_LayerNormEnabled == False


def test_l2normalization_descriptor_default_values():
    """Check the field values of a default-constructed L2NormalizationDescriptor."""
    desc = ann.L2NormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # NOTE: np.allclose's default absolute tolerance (1e-8) is far larger than
    # 1e-12, so this only verifies that m_Eps is within 1e-8 of the expected value.
    assert np.allclose(1e-12, desc.m_Eps)


def test_mean_descriptor_default_values():
    """Check the field values of a default-constructed MeanDescriptor."""
    desc = ann.MeanDescriptor()
    assert desc.m_KeepDims == False


def test_normalization_descriptor_default_values():
    """Check the field values of a default-constructed NormalizationDescriptor."""
    desc = ann.NormalizationDescriptor()
    assert desc.m_NormChannelType == ann.NormalizationAlgorithmChannel_Across
    assert desc.m_NormMethodType == ann.NormalizationAlgorithmMethod_LocalBrightness
    assert desc.m_NormSize == 0
    assert desc.m_Alpha == 0
    assert desc.m_Beta == 0
    assert desc.m_K == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW


def test_origin_descriptor_default_values():
    """Check the getters of a default-constructed ConcatDescriptor."""
    desc = ann.ConcatDescriptor()
    assert 0 == desc.GetNumViews()
    assert 0 == desc.GetNumDimensions()
    assert 1 == desc.GetConcatAxis()


def test_origin_descriptor_incorrect_views():
    """Check that out-of-range view coordinates on a ConcatDescriptor raise RuntimeError."""
    desc = ann.ConcatDescriptor(2, 2)
    with pytest.raises(RuntimeError) as err:
        desc.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(err.value)


def test_origin_descriptor_ctor():
    """Check that view origins and concat axis set on a ConcatDescriptor read back."""
    desc = ann.ConcatDescriptor(2, 2)
    value = 5
    for view in range(desc.GetNumViews()):
        for dim in range(desc.GetNumDimensions()):
            desc.SetViewOriginCoord(view, dim, value + view)
    desc.SetConcatAxis(1)

    assert 2 == desc.GetNumViews()
    assert 2 == desc.GetNumDimensions()
    assert [5, 5] == desc.GetViewOrigin(0)
    assert [6, 6] == desc.GetViewOrigin(1)
    assert 1 == desc.GetConcatAxis()


def test_pad_descriptor_default_values():
    """Check the field values of a default-constructed PadDescriptor."""
    desc = ann.PadDescriptor()
    assert desc.m_PadValue == 0
    assert desc.m_PaddingMode == ann.PaddingMode_Constant


def test_permute_descriptor_default_values():
    """Check that a PermuteDescriptor exposes the PermutationVector it was built from."""
    pv = ann.PermutationVector((0, 2, 3, 1))
    desc = ann.PermuteDescriptor(pv)
    assert desc.m_DimMappings.GetSize() == 4
    assert desc.m_DimMappings[0] == 0
    assert desc.m_DimMappings[1] == 2
    assert desc.m_DimMappings[2] == 3
    assert desc.m_DimMappings[3] == 1


def test_pooling_descriptor_default_values():
    """Check the field values of a default-constructed Pooling2dDescriptor."""
    desc = ann.Pooling2dDescriptor()
    assert desc.m_PoolType == ann.PoolingAlgorithm_Max
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_PoolHeight == 0
    assert desc.m_PoolWidth == 0
    assert desc.m_StrideX == 0
    assert desc.m_StrideY == 0
    assert desc.m_OutputShapeRounding == ann.OutputShapeRounding_Floor
    assert desc.m_PaddingMethod == ann.PaddingMethod_Exclude
    assert desc.m_DataLayout == ann.DataLayout_NCHW


def test_pooling_3d_descriptor_default_values():
    """Check the field values of a default-constructed Pooling3dDescriptor."""
    desc = ann.Pooling3dDescriptor()
    assert desc.m_PoolType == ann.PoolingAlgorithm_Max
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_PadFront == 0
    assert desc.m_PadBack == 0
    assert desc.m_PoolHeight == 0
    assert desc.m_PoolWidth == 0
    assert desc.m_StrideX == 0
    assert desc.m_StrideY == 0
    assert desc.m_StrideZ == 0
    assert desc.m_OutputShapeRounding == ann.OutputShapeRounding_Floor
    assert desc.m_PaddingMethod == ann.PaddingMethod_Exclude
    assert desc.m_DataLayout == ann.DataLayout_NCDHW


def test_reshape_descriptor_default_values():
    """Check that a default-constructed ReshapeDescriptor has an empty target shape."""
    desc = ann.ReshapeDescriptor()
    # check the empty Targetshape
    assert desc.m_TargetShape.GetNumDimensions() == 0


def test_reduce_descriptor_default_values():
    """Check the field values of a default-constructed ReduceDescriptor."""
    desc = ann.ReduceDescriptor()
    assert desc.m_KeepDims == False
    assert desc.m_vAxis == []
    assert desc.m_ReduceOperation == ann.ReduceOperation_Sum


def test_resize_descriptor_default_values():
    """Check the field values of a default-constructed ResizeDescriptor."""
    desc = ann.ResizeDescriptor()
    assert desc.m_TargetWidth == 0
    assert desc.m_TargetHeight == 0
    assert desc.m_Method == ann.ResizeMethod_NearestNeighbor
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    assert desc.m_AlignCorners == False


def test_spacetobatchnd_descriptor_default_values():
    """Check the default data layout of a SpaceToBatchNdDescriptor."""
    desc = ann.SpaceToBatchNdDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW


def test_spacetodepth_descriptor_default_values():
    """Check the field values of a default-constructed SpaceToDepthDescriptor."""
    desc = ann.SpaceToDepthDescriptor()
    assert desc.m_BlockSize == 1
    assert desc.m_DataLayout == ann.DataLayout_NHWC


def test_stack_descriptor_default_values():
    """Check the field values of a default-constructed StackDescriptor."""
    desc = ann.StackDescriptor()
    assert desc.m_Axis == 0
    assert desc.m_NumInputs == 0
    # check the empty Inputshape
    assert desc.m_InputShape.GetNumDimensions() == 0


def test_slice_descriptor_default_values():
    """Check that m_Begin and m_Size assigned on a SliceDescriptor read back as lists."""
    desc = ann.SliceDescriptor()
    desc.m_Begin = [1, 2, 3, 4, 5]
    desc.m_Size = (1, 2, 3, 4)

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [1, 2, 3, 4] == desc.m_Size


def test_slice_descriptor_ctor():
    """Check that the SliceDescriptor constructor stores begin and size."""
    desc = ann.SliceDescriptor([1, 2, 3, 4, 5], (1, 2, 3, 4))

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [1, 2, 3, 4] == desc.m_Size


def test_strided_slice_descriptor_default_values():
    """Check that values assigned to StridedSliceDescriptor fields read back unchanged."""
    desc = ann.StridedSliceDescriptor()
    desc.m_Begin = [1, 2, 3, 4, 5]
    desc.m_End = [6, 7, 8, 9, 10]
    desc.m_Stride = (10, 10)
    desc.m_BeginMask = 1
    desc.m_EndMask = 2
    desc.m_ShrinkAxisMask = 3
    desc.m_EllipsisMask = 4
    desc.m_NewAxisMask = 5

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    assert [10, 10] == desc.m_Stride
    assert 1 == desc.m_BeginMask
    assert 2 == desc.m_EndMask
    assert 3 == desc.m_ShrinkAxisMask
    assert 4 == desc.m_EllipsisMask
    assert 5 == desc.m_NewAxisMask


def test_strided_slice_descriptor_ctor():
    """Check that the StridedSliceDescriptor constructor stores begin, end and stride."""
    # The fields are deliberately not reassigned after construction; doing so
    # would make the assertions pass regardless of what the constructor stored.
    desc = ann.StridedSliceDescriptor([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], (10, 10))

    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    assert [10, 10] == desc.m_Stride


def test_softmax_descriptor_default_values():
    """Check the field values of a default-constructed SoftmaxDescriptor."""
    desc = ann.SoftmaxDescriptor()
    assert desc.m_Axis == -1
    # np.allclose only returns a bool; it has to be asserted to check anything.
    assert np.allclose(1.0, desc.m_Beta)


def test_space_to_batch_nd_descriptor_default_values():
    """Check the field values of a default-constructed SpaceToBatchNdDescriptor."""
    desc = ann.SpaceToBatchNdDescriptor()
    assert [1, 1] == desc.m_BlockShape
    assert [(0, 0), (0, 0)] == desc.m_PadList
    assert ann.DataLayout_NCHW == desc.m_DataLayout


def test_space_to_batch_nd_descriptor_assigned_values():
    """Check that values assigned to SpaceToBatchNdDescriptor fields read back unchanged."""
    desc = ann.SpaceToBatchNdDescriptor()
    desc.m_BlockShape = (90, 100)
    desc.m_PadList = [(1, 2), (3, 4)]
    assert [90, 100] == desc.m_BlockShape
    assert [(1, 2), (3, 4)] == desc.m_PadList
    assert ann.DataLayout_NCHW == desc.m_DataLayout


def test_space_to_batch_nd_descriptor_ctor():
    """Check that the SpaceToBatchNdDescriptor constructor stores block shape and pad list."""
    desc = ann.SpaceToBatchNdDescriptor((1, 2, 3), [(1, 2), (3, 4)])
    assert [1, 2, 3] == desc.m_BlockShape
    assert [(1, 2), (3, 4)] == desc.m_PadList
    assert ann.DataLayout_NCHW == desc.m_DataLayout


def test_transpose_convolution2d_descriptor_default_values():
    """Check the field values of a default-constructed TransposeConvolution2dDescriptor."""
    desc = ann.TransposeConvolution2dDescriptor()
    assert desc.m_PadLeft == 0
    assert desc.m_PadTop == 0
    assert desc.m_PadRight == 0
    assert desc.m_PadBottom == 0
    assert desc.m_StrideX == 0
    assert desc.m_StrideY == 0
    assert desc.m_BiasEnabled == False
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    assert desc.m_OutputShapeEnabled == False


def test_transpose_descriptor_default_values():
    """Check that a TransposeDescriptor exposes the PermutationVector it was built from."""
    pv = ann.PermutationVector((0, 3, 2, 1, 4))
    desc = ann.TransposeDescriptor(pv)
    assert desc.m_DimMappings.GetSize() == 5
    assert desc.m_DimMappings[0] == 0
    assert desc.m_DimMappings[1] == 3
    assert desc.m_DimMappings[2] == 2
    assert desc.m_DimMappings[3] == 1
    assert desc.m_DimMappings[4] == 4


def test_view_descriptor_default_values():
    """Check the getters of a default-constructed SplitterDescriptor."""
    desc = ann.SplitterDescriptor()
    assert 0 == desc.GetNumViews()
    assert 0 == desc.GetNumDimensions()


def test_elementwise_unary_descriptor_default_values():
    """Check the default operation of an ElementwiseUnaryDescriptor."""
    desc = ann.ElementwiseUnaryDescriptor()
    assert desc.m_Operation == ann.UnaryOperation_Abs


def test_logical_binary_descriptor_default_values():
    """Check the default operation of a LogicalBinaryDescriptor."""
    desc = ann.LogicalBinaryDescriptor()
    assert desc.m_Operation == ann.LogicalBinaryOperation_LogicalAnd


def test_view_descriptor_incorrect_input():
    """Check that out-of-range view origin/size on a SplitterDescriptor raise RuntimeError."""
    desc = ann.SplitterDescriptor(2, 3)
    with pytest.raises(RuntimeError) as err:
        desc.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(err.value)

    with pytest.raises(RuntimeError) as err:
        desc.SetViewSize(1000, 100, 1000)
    assert "Failed to set view size." in str(err.value)


def test_view_descriptor_ctor():
    """Check that view origins and sizes set on a SplitterDescriptor read back."""
    desc = ann.SplitterDescriptor(2, 3)
    value_size = 1
    value_orig_coord = 5
    for view in range(desc.GetNumViews()):
        for dim in range(desc.GetNumDimensions()):
            desc.SetViewOriginCoord(view, dim, value_orig_coord + view)
            desc.SetViewSize(view, dim, value_size + view)

    assert 2 == desc.GetNumViews()
    assert 3 == desc.GetNumDimensions()
    # NOTE(review): two-element results are expected although the descriptor
    # reports three dimensions - confirm against the binding's getters.
    assert [5, 5] == desc.GetViewOrigin(0)
    assert [6, 6] == desc.GetViewOrigin(1)
    assert [1, 1] == desc.GetViewSizes(0)
    assert [2, 2] == desc.GetViewSizes(1)


def test_createdescriptorforconcatenation_ctor():
    """Check the descriptor built by CreateDescriptorForConcatenation for three inputs."""
    input_shape_vector = [ann.TensorShape((2, 1)), ann.TensorShape((3, 1)), ann.TensorShape((4, 1))]
    desc = ann.CreateDescriptorForConcatenation(input_shape_vector, 0)
    assert 3 == desc.GetNumViews()
    assert 0 == desc.GetConcatAxis()
    assert 2 == desc.GetNumDimensions()
    # The returned origins are not compared against expected values; these
    # calls only verify that querying existing views does not raise.
    desc.GetViewOrigin(1)
    desc.GetViewOrigin(0)


def test_createdescriptorforconcatenation_wrong_shape_for_axis():
    """Check that inputs mismatching off the concat axis raise RuntimeError."""
    input_shape_vector = [ann.TensorShape((1, 2)), ann.TensorShape((3, 4)), ann.TensorShape((5, 6))]
    with pytest.raises(RuntimeError) as err:
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    assert "All inputs to concatenation must be the same size along all dimensions except the concatenation dimension" in str(
        err.value)


@pytest.mark.parametrize("input_shape_vector", [([-1, "one"]),
                                                ([1.33, 4.55]),
                                                ([{1: "one"}])], ids=str)
def test_createdescriptorforconcatenation_rubbish_assignment_shape_vector(input_shape_vector):
    """Check that a shape vector of non-TensorShape items raises TypeError."""
    with pytest.raises(TypeError) as err:
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    assert "in method 'CreateDescriptorForConcatenation', argument 1 of type 'std::vector< armnn::TensorShape,std::allocator< armnn::TensorShape > >'" in str(
        err.value)


# (name, class) pairs for every class exposed by the generated module, and
# just their names; both are consumed by TestDescriptorMassChecks below.
generated_classes = inspect.getmembers(generated, inspect.isclass)
generated_classes_names = [name for name, _ in generated_classes]


@pytest.mark.parametrize("desc_name", ['ActivationDescriptor',
                                       'ArgMinMaxDescriptor',
                                       'PermuteDescriptor',
                                       'SoftmaxDescriptor',
                                       'ConcatDescriptor',
                                       'SplitterDescriptor',
                                       'Pooling2dDescriptor',
                                       'FullyConnectedDescriptor',
                                       'Convolution2dDescriptor',
                                       'Convolution3dDescriptor',
                                       'DepthwiseConvolution2dDescriptor',
                                       'DetectionPostProcessDescriptor',
                                       'NormalizationDescriptor',
                                       'L2NormalizationDescriptor',
                                       'BatchNormalizationDescriptor',
                                       'InstanceNormalizationDescriptor',
                                       'BatchToSpaceNdDescriptor',
                                       'FakeQuantizationDescriptor',
                                       'ReduceDescriptor',
                                       'ResizeDescriptor',
                                       'ReshapeDescriptor',
                                       'SpaceToBatchNdDescriptor',
                                       'SpaceToDepthDescriptor',
                                       'LstmDescriptor',
                                       'MeanDescriptor',
                                       'PadDescriptor',
                                       'SliceDescriptor',
                                       'StackDescriptor',
                                       'StridedSliceDescriptor',
                                       'TransposeConvolution2dDescriptor',
                                       'TransposeDescriptor',
                                       'ElementwiseUnaryDescriptor',
                                       'FillDescriptor',
                                       'GatherDescriptor',
                                       'LogicalBinaryDescriptor',
                                       'ChannelShuffleDescriptor'])
class TestDescriptorMassChecks:
    """Checks applied uniformly to every descriptor class name listed above."""

    def test_desc_implemented(self, desc_name):
        """Check that the generated module exposes a class named ``desc_name``."""
        assert desc_name in generated_classes_names

    def test_desc_equal(self, desc_name):
        """Check that two default-constructed instances of the class compare equal."""
        desc_class = dict(generated_classes)[desc_name]

        assert desc_class() == desc_class()