# xref: /aosp_15_r20/external/armnn/python/pyarmnn/test/test_descriptors.py (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
# Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3import inspect
4
5import pytest
6
7import pyarmnn as ann
8import numpy as np
9import pyarmnn._generated.pyarmnn as generated
10
11
def test_activation_descriptor_default_values():
    """A default ActivationDescriptor is a sigmoid with zeroed A/B parameters."""
    descriptor = ann.ActivationDescriptor()
    expected = {
        "m_Function": ann.ActivationFunction_Sigmoid,
        "m_A": 0,
        "m_B": 0,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
17
18
def test_argminmax_descriptor_default_values():
    """A default ArgMinMaxDescriptor selects Min and uses axis -1."""
    descriptor = ann.ArgMinMaxDescriptor()
    assert ann.ArgMinMaxFunction_Min == descriptor.m_Function
    assert -1 == descriptor.m_Axis
23
24
def test_batchnormalization_descriptor_default_values():
    """A default BatchNormalizationDescriptor uses NCHW and an eps of 1e-4."""
    desc = ann.BatchNormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # np.allclose only returns a bool; without `assert` the eps check was a no-op.
    assert np.allclose(0.0001, desc.m_Eps)
29
30
def test_batchtospacend_descriptor_default_values():
    """Default BatchToSpaceNd: NCHW layout, [1, 1] block shape, zero crops."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    assert ann.DataLayout_NCHW == descriptor.m_DataLayout
    assert descriptor.m_BlockShape == [1, 1]
    assert descriptor.m_Crops == [(0, 0), (0, 0)]
36
37
def test_batchtospacend_descriptor_assignment():
    """Block shape, crops and data layout can be assigned and read back."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    descriptor.m_BlockShape = (1, 2, 3)

    crops = [(1, 2), (3, 4)]
    crops_len_before = len(crops)
    descriptor.m_Crops = crops
    # Assigning the crops must not change the length of the caller's list.
    assert len(crops) == crops_len_before

    descriptor.m_DataLayout = ann.DataLayout_NHWC
    assert descriptor.m_DataLayout == ann.DataLayout_NHWC
    assert descriptor.m_BlockShape == [1, 2, 3]
    assert descriptor.m_Crops == [(1, 2), (3, 4)]
51
52
@pytest.mark.parametrize(
    "input_shape, value, vtype",
    [
        ([-1], -1, 'int'),
        (("one", "two"), "'one'", 'str'),
        ([1.33, 4.55], 1.33, 'float'),
        ([{1: "one"}], "{1: 'one'}", 'dict'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_shape(input_shape, value, vtype):
    """Assigning an invalid block shape raises TypeError.

    The message names the first rejected element, its Python type, and the
    C type 'j' that the conversion targeted.
    """
    descriptor = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as excinfo:
        descriptor.m_BlockShape = input_shape

    expected_msg = "Failed to convert python input value {} of type '{}' to C type 'j'".format(value, vtype)
    assert expected_msg in str(excinfo.value)
62
63
@pytest.mark.parametrize(
    "input_crops, value, vtype",
    [
        ([(1, 2), (3, 4, 5)], '(3, 4, 5)', 'tuple'),
        ([(1, 'one')], "(1, 'one')", 'tuple'),
        ([-1], -1, 'int'),
        ([(1, (1, 2))], '(1, (1, 2))', 'tuple'),
        ([[1, [1, 2]]], '[1, [1, 2]]', 'list'),
    ],
    ids=str,
)
def test_batchtospacend_descriptor_rubbish_assignment_crops(input_crops, value, vtype):
    """Malformed crop entries are rejected with a TypeError naming the bad entry."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    with pytest.raises(TypeError) as excinfo:
        descriptor.m_Crops = input_crops

    expected_prefix = "Failed to convert python input value {} of type '{}' to C type".format(value, vtype)
    assert expected_prefix in str(excinfo.value)
76
77
def test_batchtospacend_descriptor_empty_assignment():
    """An empty block shape round-trips as an empty list."""
    descriptor = ann.BatchToSpaceNdDescriptor()
    descriptor.m_BlockShape = []
    assert descriptor.m_BlockShape == []
82
83
def test_batchtospacend_descriptor_ctor():
    """Constructor arguments are stored; the data layout stays NCHW."""
    block_shape = [1, 2, 3]
    crops = [(4, 5), (6, 7)]
    descriptor = ann.BatchToSpaceNdDescriptor(block_shape, crops)
    assert ann.DataLayout_NCHW == descriptor.m_DataLayout
    assert descriptor.m_BlockShape == [1, 2, 3]
    assert descriptor.m_Crops == [(4, 5), (6, 7)]
89
90
def test_channelshuffle_descriptor_default_values():
    """A default ChannelShuffleDescriptor has axis 0 and zero groups."""
    descriptor = ann.ChannelShuffleDescriptor()
    assert 0 == descriptor.m_Axis
    assert 0 == descriptor.m_NumGroups
95
def test_convolution2d_descriptor_default_values():
    """Default Convolution2d: no padding, unit stride/dilation, no bias, NCHW."""
    descriptor = ann.Convolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
108
def test_convolution3d_descriptor_default_values():
    """Default Convolution3d: no padding, unit stride/dilation, no bias, NDHWC."""
    descriptor = ann.Convolution3dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PadFront": 0,
        "m_PadBack": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_StrideZ": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_DilationZ": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NDHWC,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
125
126
def test_depthtospace_descriptor_default_values():
    """A default DepthToSpaceDescriptor has block size 1 and NHWC layout."""
    descriptor = ann.DepthToSpaceDescriptor()
    assert 1 == descriptor.m_BlockSize
    assert ann.DataLayout_NHWC == descriptor.m_DataLayout
131
132
def test_depthwise_convolution2d_descriptor_default_values():
    """Default DepthwiseConvolution2d: no padding, unit stride/dilation, no bias, NCHW."""
    descriptor = ann.DepthwiseConvolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 1,
        "m_StrideY": 1,
        "m_DilationX": 1,
        "m_DilationY": 1,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
145
146
def test_detectionpostprocess_descriptor_default_values():
    """Default DetectionPostProcess fields: counts, thresholds, NMS flag and scales."""
    descriptor = ann.DetectionPostProcessDescriptor()
    expected = {
        "m_MaxDetections": 0,
        "m_MaxClassesPerDetection": 1,
        "m_DetectionsPerClass": 1,
        "m_NmsScoreThreshold": 0,
        "m_NmsIouThreshold": 0,
        "m_NumClasses": 0,
        "m_UseRegularNms": False,
        "m_ScaleH": 0,
        "m_ScaleW": 0,
        "m_ScaleX": 0,
        "m_ScaleY": 0,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
160
161
def test_fakequantization_descriptor_default_values():
    """A default FakeQuantizationDescriptor spans the range [-6, 6]."""
    desc = ann.FakeQuantizationDescriptor()
    # np.allclose only returns a bool; it must be asserted for the check to count.
    assert np.allclose(6, desc.m_Max)
    assert np.allclose(-6, desc.m_Min)
166
167
def test_fill_descriptor_default_values():
    """A default FillDescriptor fills with the value 0."""
    desc = ann.FillDescriptor()
    # np.allclose only returns a bool; it must be asserted for the check to count.
    assert np.allclose(0, desc.m_Value)
171
172
def test_gather_descriptor_default_values():
    """A default GatherDescriptor gathers along axis 0."""
    assert 0 == ann.GatherDescriptor().m_Axis
176
177
def test_fully_connected_descriptor_default_values():
    """A default FullyConnectedDescriptor has bias and weight transposition disabled."""
    descriptor = ann.FullyConnectedDescriptor()
    for field in ("m_BiasEnabled", "m_TransposeWeightMatrix"):
        assert getattr(descriptor, field) == False, field
182
183
def test_instancenormalization_descriptor_default_values():
    """Default InstanceNormalization: gamma 1, beta 0, NCHW layout, eps 1e-12."""
    desc = ann.InstanceNormalizationDescriptor()
    assert desc.m_Gamma == 1
    assert desc.m_Beta == 0
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The allclose result must be asserted. With allclose's default atol=1e-8
    # any near-zero value (even 0) would "match" 1e-12, so use a purely
    # relative tolerance that still absorbs float32 rounding.
    assert np.allclose(1e-12, desc.m_Eps, rtol=1e-5, atol=0)
190
191
def test_lstm_descriptor_default_values():
    """Default LSTM: activation code 1, no clipping, CIFG on, other options off."""
    descriptor = ann.LstmDescriptor()
    expected = {
        "m_ActivationFunc": 1,
        "m_ClippingThresCell": 0,
        "m_ClippingThresProj": 0,
        "m_CifgEnabled": True,
        "m_PeepholeEnabled": False,
        "m_ProjectionEnabled": False,
        "m_LayerNormEnabled": False,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
201
202
def test_l2normalization_descriptor_default_values():
    """A default L2NormalizationDescriptor uses NCHW and an eps of 1e-12."""
    desc = ann.L2NormalizationDescriptor()
    assert desc.m_DataLayout == ann.DataLayout_NCHW
    # The allclose result must be asserted. With allclose's default atol=1e-8
    # any near-zero value (even 0) would "match" 1e-12, so use a purely
    # relative tolerance that still absorbs float32 rounding.
    assert np.allclose(1e-12, desc.m_Eps, rtol=1e-5, atol=0)
207
208
def test_mean_descriptor_default_values():
    """A default MeanDescriptor does not keep reduced dimensions."""
    assert ann.MeanDescriptor().m_KeepDims == False
212
213
def test_normalization_descriptor_default_values():
    """Default Normalization: Across/LocalBrightness, zeroed parameters, NCHW."""
    descriptor = ann.NormalizationDescriptor()
    expected = {
        "m_NormChannelType": ann.NormalizationAlgorithmChannel_Across,
        "m_NormMethodType": ann.NormalizationAlgorithmMethod_LocalBrightness,
        "m_NormSize": 0,
        "m_Alpha": 0,
        "m_Beta": 0,
        "m_K": 0,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
223
224
def test_origin_descriptor_default_values():
    """A default ConcatDescriptor has no views or dimensions and concat axis 1."""
    descriptor = ann.ConcatDescriptor()
    assert descriptor.GetNumViews() == 0
    assert descriptor.GetNumDimensions() == 0
    assert descriptor.GetConcatAxis() == 1
230
231
def test_origin_descriptor_incorrect_views():
    """Setting an origin coordinate with out-of-range indices raises RuntimeError."""
    descriptor = ann.ConcatDescriptor(2, 2)
    with pytest.raises(RuntimeError) as excinfo:
        descriptor.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(excinfo.value)
237
238
def test_origin_descriptor_ctor():
    """Per-view origin coordinates are returned by GetViewOrigin."""
    descriptor = ann.ConcatDescriptor(2, 2)
    base_coord = 5
    for view in range(descriptor.GetNumViews()):
        for dim in range(descriptor.GetNumDimensions()):
            # Every coordinate of view N is set to base_coord + N.
            descriptor.SetViewOriginCoord(view, dim, base_coord + view)
    descriptor.SetConcatAxis(1)

    assert descriptor.GetNumViews() == 2
    assert descriptor.GetNumDimensions() == 2
    assert descriptor.GetViewOrigin(0) == [5, 5]
    assert descriptor.GetViewOrigin(1) == [6, 6]
    assert descriptor.GetConcatAxis() == 1
252
253
def test_pad_descriptor_default_values():
    """A default PadDescriptor pads with 0 in Constant mode."""
    descriptor = ann.PadDescriptor()
    assert 0 == descriptor.m_PadValue
    assert ann.PaddingMode_Constant == descriptor.m_PaddingMode
258
259
def test_permute_descriptor_default_values():
    """PermuteDescriptor exposes the mappings of the PermutationVector it was built from."""
    mapping = (0, 2, 3, 1)
    permutation = ann.PermutationVector(mapping)
    descriptor = ann.PermuteDescriptor(permutation)
    assert descriptor.m_DimMappings.GetSize() == len(mapping)
    for index, expected in enumerate(mapping):
        assert descriptor.m_DimMappings[index] == expected
268
269
def test_pooling_descriptor_default_values():
    """Default Pooling2d: Max pooling, zero pads/pool size/stride, Floor, Exclude, NCHW."""
    descriptor = ann.Pooling2dDescriptor()
    expected = {
        "m_PoolType": ann.PoolingAlgorithm_Max,
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PoolHeight": 0,
        "m_PoolWidth": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_OutputShapeRounding": ann.OutputShapeRounding_Floor,
        "m_PaddingMethod": ann.PaddingMethod_Exclude,
        "m_DataLayout": ann.DataLayout_NCHW,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
284
def test_pooling_3d_descriptor_default_values():
    """Default Pooling3d: Max pooling, zero pads/pool size/stride, Floor, Exclude, NCDHW."""
    descriptor = ann.Pooling3dDescriptor()
    expected = {
        "m_PoolType": ann.PoolingAlgorithm_Max,
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_PadFront": 0,
        "m_PadBack": 0,
        "m_PoolHeight": 0,
        "m_PoolWidth": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_StrideZ": 0,
        "m_OutputShapeRounding": ann.OutputShapeRounding_Floor,
        "m_PaddingMethod": ann.PaddingMethod_Exclude,
        "m_DataLayout": ann.DataLayout_NCDHW,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
302
303
def test_reshape_descriptor_default_values():
    """A default ReshapeDescriptor has an empty (zero-dimension) target shape."""
    target_shape = ann.ReshapeDescriptor().m_TargetShape
    assert 0 == target_shape.GetNumDimensions()
308
def test_reduce_descriptor_default_values():
    """A default ReduceDescriptor sums over no explicit axes and drops reduced dims."""
    descriptor = ann.ReduceDescriptor()
    assert descriptor.m_KeepDims == False
    assert [] == descriptor.m_vAxis
    assert ann.ReduceOperation_Sum == descriptor.m_ReduceOperation
314
def test_slice_descriptor_empty_default_values():
    """A default-constructed SliceDescriptor has empty begin and size vectors.

    This test has its own name so that it is not shadowed by the later
    same-module test covering m_Begin/m_Size assignment.
    """
    desc = ann.SliceDescriptor()
    assert [] == desc.m_Begin
    assert [] == desc.m_Size
321
322
def test_resize_descriptor_default_values():
    """Default Resize: zero target size, nearest neighbour, NCHW, no corner alignment."""
    descriptor = ann.ResizeDescriptor()
    expected = {
        "m_TargetWidth": 0,
        "m_TargetHeight": 0,
        "m_Method": ann.ResizeMethod_NearestNeighbor,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_AlignCorners": False,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
330
331
def test_spacetobatchnd_descriptor_default_values():
    """A default SpaceToBatchNdDescriptor uses the NCHW layout."""
    assert ann.DataLayout_NCHW == ann.SpaceToBatchNdDescriptor().m_DataLayout
335
336
def test_spacetodepth_descriptor_default_values():
    """A default SpaceToDepthDescriptor has block size 1 and NHWC layout."""
    descriptor = ann.SpaceToDepthDescriptor()
    assert 1 == descriptor.m_BlockSize
    assert ann.DataLayout_NHWC == descriptor.m_DataLayout
341
342
def test_stack_descriptor_default_values():
    """A default StackDescriptor has axis 0, no inputs and an empty input shape."""
    descriptor = ann.StackDescriptor()
    assert 0 == descriptor.m_Axis
    assert 0 == descriptor.m_NumInputs
    assert 0 == descriptor.m_InputShape.GetNumDimensions()
349
350
def test_slice_descriptor_default_values():
    """m_Begin and m_Size accept list/tuple assignment and read back as lists."""
    descriptor = ann.SliceDescriptor()
    descriptor.m_Begin = [1, 2, 3, 4, 5]
    descriptor.m_Size = (1, 2, 3, 4)

    assert descriptor.m_Begin == [1, 2, 3, 4, 5]
    assert descriptor.m_Size == [1, 2, 3, 4]
358
359
def test_slice_descriptor_ctor():
    """Begin and size passed to the SliceDescriptor constructor are stored as lists."""
    begin, size = [1, 2, 3, 4, 5], (1, 2, 3, 4)
    descriptor = ann.SliceDescriptor(begin, size)

    assert descriptor.m_Begin == [1, 2, 3, 4, 5]
    assert descriptor.m_Size == [1, 2, 3, 4]
365
366
def test_strided_slice_descriptor_default_values():
    """Begin/end/stride vectors and all bit masks can be assigned and read back."""
    descriptor = ann.StridedSliceDescriptor()
    descriptor.m_Begin = [1, 2, 3, 4, 5]
    descriptor.m_End = [6, 7, 8, 9, 10]
    descriptor.m_Stride = (10, 10)
    masks = {
        "m_BeginMask": 1,
        "m_EndMask": 2,
        "m_ShrinkAxisMask": 3,
        "m_EllipsisMask": 4,
        "m_NewAxisMask": 5,
    }
    for name, bits in masks.items():
        setattr(descriptor, name, bits)

    assert descriptor.m_Begin == [1, 2, 3, 4, 5]
    assert descriptor.m_End == [6, 7, 8, 9, 10]
    assert descriptor.m_Stride == [10, 10]
    for name, bits in masks.items():
        assert getattr(descriptor, name) == bits, name
386
387
def test_strided_slice_descriptor_ctor():
    """Begin, end and stride passed to the constructor are stored as lists."""
    desc = ann.StridedSliceDescriptor([1, 2, 3, 4, 5], [6, 7, 8, 9, 10], (10, 10))

    # Do not re-assign the fields before checking them: that would overwrite
    # what the constructor stored, and the test would not exercise it.
    assert [1, 2, 3, 4, 5] == desc.m_Begin
    assert [6, 7, 8, 9, 10] == desc.m_End
    assert [10, 10] == desc.m_Stride
397
398
def test_softmax_descriptor_default_values():
    """A default SoftmaxDescriptor uses axis -1 and beta 1.0."""
    desc = ann.SoftmaxDescriptor()
    assert desc.m_Axis == -1
    # np.allclose only returns a bool; it must be asserted for the check to count.
    assert np.allclose(1.0, desc.m_Beta)
403
404
def test_space_to_batch_nd_descriptor_default_values():
    """Default SpaceToBatchNd: [1, 1] block shape, zero padding, NCHW layout."""
    descriptor = ann.SpaceToBatchNdDescriptor()
    assert descriptor.m_BlockShape == [1, 1]
    assert descriptor.m_PadList == [(0, 0), (0, 0)]
    assert descriptor.m_DataLayout == ann.DataLayout_NCHW
410
411
def test_space_to_batch_nd_descriptor_assigned_values():
    """Assigned block shape and pad list read back as lists; layout is untouched."""
    descriptor = ann.SpaceToBatchNdDescriptor()
    descriptor.m_BlockShape = (90, 100)
    descriptor.m_PadList = [(1, 2), (3, 4)]

    assert descriptor.m_BlockShape == [90, 100]
    assert descriptor.m_PadList == [(1, 2), (3, 4)]
    assert descriptor.m_DataLayout == ann.DataLayout_NCHW
419
420
def test_space_to_batch_nd_descriptor_ctor():
    """Constructor arguments populate block shape and pad list; layout stays NCHW."""
    block_shape, pad_list = (1, 2, 3), [(1, 2), (3, 4)]
    descriptor = ann.SpaceToBatchNdDescriptor(block_shape, pad_list)

    assert descriptor.m_BlockShape == [1, 2, 3]
    assert descriptor.m_PadList == [(1, 2), (3, 4)]
    assert descriptor.m_DataLayout == ann.DataLayout_NCHW
426
427
def test_transpose_convolution2d_descriptor_default_values():
    """Default TransposeConvolution2d: zero pads/strides, no bias, NCHW, no output shape."""
    descriptor = ann.TransposeConvolution2dDescriptor()
    expected = {
        "m_PadLeft": 0,
        "m_PadTop": 0,
        "m_PadRight": 0,
        "m_PadBottom": 0,
        "m_StrideX": 0,
        "m_StrideY": 0,
        "m_BiasEnabled": False,
        "m_DataLayout": ann.DataLayout_NCHW,
        "m_OutputShapeEnabled": False,
    }
    for field, value in expected.items():
        assert getattr(descriptor, field) == value, field
439
def test_transpose_descriptor_default_values():
    """TransposeDescriptor exposes the mappings of the PermutationVector it was built from."""
    mapping = (0, 3, 2, 1, 4)
    permutation = ann.PermutationVector(mapping)
    descriptor = ann.TransposeDescriptor(permutation)
    assert descriptor.m_DimMappings.GetSize() == len(mapping)
    for index, expected in enumerate(mapping):
        assert descriptor.m_DimMappings[index] == expected
449
def test_view_descriptor_default_values():
    """A default SplitterDescriptor has no views and no dimensions."""
    descriptor = ann.SplitterDescriptor()
    assert descriptor.GetNumViews() == 0
    assert descriptor.GetNumDimensions() == 0
454
455
def test_elementwise_unary_descriptor_default_values():
    """A default ElementwiseUnaryDescriptor performs Abs."""
    assert ann.UnaryOperation_Abs == ann.ElementwiseUnaryDescriptor().m_Operation
459
460
def test_logical_binary_descriptor_default_values():
    """A default LogicalBinaryDescriptor performs LogicalAnd."""
    assert ann.LogicalBinaryOperation_LogicalAnd == ann.LogicalBinaryDescriptor().m_Operation
464
def test_view_descriptor_incorrect_input():
    """Out-of-range view/dimension indices make both setters raise RuntimeError."""
    descriptor = ann.SplitterDescriptor(2, 3)

    with pytest.raises(RuntimeError) as origin_err:
        descriptor.SetViewOriginCoord(1000, 100, 1000)
    assert "Failed to set view origin coordinates." in str(origin_err.value)

    with pytest.raises(RuntimeError) as size_err:
        descriptor.SetViewSize(1000, 100, 1000)
    assert "Failed to set view size." in str(size_err.value)
474
475
def test_view_descriptor_ctor():
    """Origins and sizes set on a 2-view, 3-dimension SplitterDescriptor read back."""
    descriptor = ann.SplitterDescriptor(2, 3)
    base_size = 1
    base_origin = 5
    for view in range(descriptor.GetNumViews()):
        for dim in range(descriptor.GetNumDimensions()):
            descriptor.SetViewOriginCoord(view, dim, base_origin + view)
            descriptor.SetViewSize(view, dim, base_size + view)

    assert descriptor.GetNumViews() == 2
    assert descriptor.GetNumDimensions() == 3
    # NOTE(review): each view gets 3 coordinates, but only 2 values are
    # expected back. The binding may size the returned list by the number of
    # views rather than dimensions; confirm against the SWIG typemap.
    assert descriptor.GetViewOrigin(0) == [5, 5]
    assert descriptor.GetViewOrigin(1) == [6, 6]
    assert descriptor.GetViewSizes(0) == [1, 1]
    assert descriptor.GetViewSizes(1) == [2, 2]
491
492
def test_createdescriptorforconcatenation_ctor():
    """CreateDescriptorForConcatenation builds one view per input shape."""
    shapes = [ann.TensorShape(dims) for dims in ((2, 1), (3, 1), (4, 1))]
    descriptor = ann.CreateDescriptorForConcatenation(shapes, 0)
    assert descriptor.GetNumViews() == 3
    assert descriptor.GetConcatAxis() == 0
    assert descriptor.GetNumDimensions() == 2
    # Smoke check only: the view origins are fetched but their contents are not
    # asserted. NOTE(review): the returned list length may not match the
    # dimension count (see test_view_descriptor_ctor); verify before asserting.
    descriptor.GetViewOrigin(1)
    descriptor.GetViewOrigin(0)
501
502
def test_createdescriptorforconcatenation_wrong_shape_for_axis():
    """Inputs that differ outside the concatenation axis are rejected with RuntimeError."""
    shapes = [ann.TensorShape(dims) for dims in ((1, 2), (3, 4), (5, 6))]
    with pytest.raises(RuntimeError) as excinfo:
        ann.CreateDescriptorForConcatenation(shapes, 0)

    expected_msg = "All inputs to concatenation must be the same size along all dimensions  except the concatenation dimension"
    assert expected_msg in str(excinfo.value)
510
511
@pytest.mark.parametrize(
    "input_shape_vector",
    [
        [-1, "one"],
        [1.33, 4.55],
        [{1: "one"}],
    ],
    ids=str,
)
def test_createdescriptorforconcatenation_rubbish_assignment_shape_vector(input_shape_vector):
    """Passing non-TensorShape elements raises a SWIG TypeError for argument 1."""
    with pytest.raises(TypeError) as excinfo:
        ann.CreateDescriptorForConcatenation(input_shape_vector, 0)

    expected_msg = "in method 'CreateDescriptorForConcatenation', argument 1 of type 'std::vector< armnn::TensorShape,std::allocator< armnn::TensorShape > >'"
    assert expected_msg in str(excinfo.value)
521
522
# All classes exported by the SWIG-generated module, collected once so the
# parametrized checks below can look descriptors up by name.
generated_classes = inspect.getmembers(generated, inspect.isclass)
generated_classes_names = [name for name, _ in generated_classes]
_generated_classes_by_name = dict(generated_classes)

# Descriptor classes that must be exposed by the generated bindings and must
# compare equal when default-constructed.
_MASS_CHECKED_DESCRIPTORS = [
    'ActivationDescriptor',
    'ArgMinMaxDescriptor',
    'PermuteDescriptor',
    'SoftmaxDescriptor',
    'ConcatDescriptor',
    'SplitterDescriptor',
    'Pooling2dDescriptor',
    'FullyConnectedDescriptor',
    'Convolution2dDescriptor',
    'Convolution3dDescriptor',
    'DepthwiseConvolution2dDescriptor',
    'DetectionPostProcessDescriptor',
    'NormalizationDescriptor',
    'L2NormalizationDescriptor',
    'BatchNormalizationDescriptor',
    'InstanceNormalizationDescriptor',
    'BatchToSpaceNdDescriptor',
    'FakeQuantizationDescriptor',
    'ReduceDescriptor',
    'ResizeDescriptor',
    'ReshapeDescriptor',
    'SpaceToBatchNdDescriptor',
    'SpaceToDepthDescriptor',
    'LstmDescriptor',
    'MeanDescriptor',
    'PadDescriptor',
    'SliceDescriptor',
    'StackDescriptor',
    'StridedSliceDescriptor',
    'TransposeConvolution2dDescriptor',
    'TransposeDescriptor',
    'ElementwiseUnaryDescriptor',
    'FillDescriptor',
    'GatherDescriptor',
    'LogicalBinaryDescriptor',
    'ChannelShuffleDescriptor',
]


@pytest.mark.parametrize("desc_name", _MASS_CHECKED_DESCRIPTORS)
class TestDescriptorMassChecks:
    """Smoke checks applied to every descriptor in _MASS_CHECKED_DESCRIPTORS."""

    def test_desc_implemented(self, desc_name):
        """The descriptor class is exported by the generated bindings."""
        assert desc_name in generated_classes_names

    def test_desc_equal(self, desc_name):
        """Two default-constructed instances of the descriptor compare equal."""
        # Fail with a readable assertion instead of a bare StopIteration when
        # the class is missing from the generated module.
        assert desc_name in _generated_classes_by_name, desc_name
        desc_class = _generated_classes_by_name[desc_name]

        assert desc_class() == desc_class()
619
620