xref: /aosp_15_r20/external/armnn/src/backends/tosaCommon/test/AvgPool2DIgnoreValueChecker.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "TosaTestUtils.hpp"
7 
8 using namespace armnn;
9 using namespace tosa;
10 
VerifyAvgPool2DIgnoreValue(TosaSerializationBasicBlock * basicBlock,std::vector<std::vector<int32_t>> inputShape,std::vector<std::vector<int32_t>> outputShape,std::vector<std::vector<int32_t>> intermediateShape,const BaseDescriptor & descriptor,DType dataType=DType_FP32)11 void VerifyAvgPool2DIgnoreValue(TosaSerializationBasicBlock* basicBlock,
12                                 std::vector<std::vector<int32_t>> inputShape,
13                                 std::vector<std::vector<int32_t>> outputShape,
14                                 std::vector<std::vector<int32_t>> intermediateShape,
15                                 const BaseDescriptor& descriptor,
16                                 DType dataType = DType_FP32)
17 {
18     uint32_t numInputs = static_cast<uint32_t>(inputShape.size());
19     uint32_t numOutputs = static_cast<uint32_t>(outputShape.size());
20 
21     std::string blockStr = TosaOpToString(Op_AVG_POOL2D) + "_block_";
22     CHECK(basicBlock->GetName().find(blockStr)  != std::string::npos);
23     CHECK(basicBlock->GetInputs().size() == numInputs);
24     CHECK(basicBlock->GetOutputs().size() == numOutputs);
25     CHECK(basicBlock->GetOperators().size() == 2);
26     CHECK(basicBlock->GetTensors().size() == 3);
27 
28     //
29     // Verify padding operator first.
30     //
31 
32     TosaSerializationOperator* padOp = basicBlock->GetOperators().at(0);
33     uint32_t padOpOutputs = 1;
34     CHECK(padOp->GetInputTensorNames().size() == numInputs);
35     CHECK(padOp->GetOutputTensorNames().size() == padOpOutputs);
36 
37     for (uint32_t i = 0; i < numInputs; i++)
38     {
39         std::basic_string<char> blockInputName = basicBlock->GetInputs()[i];
40         std::basic_string<char> operatorInputName  = padOp->GetInputTensorNames()[i];
41 
42         std::string opStr = "input" + std::to_string(i) + "_";
43 
44         CHECK(blockInputName == operatorInputName);
45         CHECK(basicBlock->GetTensorByName(blockInputName));
46         CHECK(blockInputName.find(opStr)  != std::string::npos);
47 
48         TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorInputName);
49         CHECK(tensor->GetDtype() == dataType);
50         CHECK(tensor->GetData().size() == 0);
51         CHECK(tensor->GetShape() == inputShape[static_cast<unsigned long int>(i)]);
52     }
53 
54     for (uint32_t i = 0; i < padOpOutputs; i++)
55     {
56         std::basic_string<char> operatorOutputName  = padOp->GetOutputTensorNames()[i];
57         std::string opStr = "intermediate" + std::to_string(i) + "_";
58 
59         CHECK(basicBlock->GetTensorByName(operatorOutputName));
60         CHECK(operatorOutputName.find(opStr)  != std::string::npos);
61 
62         TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorOutputName);
63         CHECK(tensor->GetDtype() == dataType);
64         CHECK(tensor->GetData().size() == 0);
65         CHECK(tensor->GetShape() == intermediateShape[static_cast<unsigned long int>(i)]);
66     }
67 
68     CHECK(padOp->GetAttributeType() == Attribute_PadAttribute);
69     CHECK(padOp->GetOp() == Op_PAD);
70 
71     VerifyTosaAttribute(descriptor,
72                         padOp->GetAttribute(),
73                         inputShape[0],
74                         outputShape[0],
75                         LayerType::Pooling2d);
76 
77     //
78     // Verify average pool operator second.
79     //
80 
81     TosaSerializationOperator* poolOp = basicBlock->GetOperators().at(1);
82     uint32_t poolOpInputs = 1;
83     CHECK(poolOp->GetInputTensorNames().size() == poolOpInputs);
84     CHECK(poolOp->GetOutputTensorNames().size() == numOutputs);
85 
86     for (uint32_t i = 0; i < poolOpInputs; i++)
87     {
88         std::basic_string<char> operatorInputName  = poolOp->GetInputTensorNames()[i];
89         std::string opStr = "intermediate" + std::to_string(i) + "_";
90 
91         CHECK(basicBlock->GetTensorByName(operatorInputName));
92         CHECK(operatorInputName.find(opStr)  != std::string::npos);
93 
94         TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorInputName);
95         CHECK(tensor->GetDtype() == dataType);
96         CHECK(tensor->GetData().size() == 0);
97         CHECK(tensor->GetShape() == intermediateShape[static_cast<unsigned long int>(i)]);
98     }
99 
100     for (uint32_t i = 0; i < numOutputs; i++)
101     {
102         std::basic_string<char> blockOutputName = basicBlock->GetOutputs()[i];
103         std::basic_string<char> operatorOutputName  = poolOp->GetOutputTensorNames()[i];
104 
105         std::string opStr = "output" + std::to_string(i) + "_";
106 
107         CHECK(blockOutputName == operatorOutputName);
108         CHECK(basicBlock->GetTensorByName(blockOutputName));
109         CHECK(blockOutputName.find(opStr)  != std::string::npos);
110 
111         TosaSerializationTensor* tensor = basicBlock->GetTensorByName(operatorOutputName);
112         CHECK(tensor->GetDtype() == dataType);
113         CHECK(tensor->GetData().size() == 0);
114         CHECK(tensor->GetShape() == outputShape[static_cast<unsigned long int>(i)]);
115     }
116 
117     CHECK(poolOp->GetAttributeType() == Attribute_PoolAttribute);
118     CHECK(poolOp->GetOp() == Op_AVG_POOL2D);
119 
120     VerifyTosaAttribute(descriptor,
121                         poolOp->GetAttribute(),
122                         inputShape[0],
123                         outputShape[0],
124                         LayerType::Pooling2d,
125                         1);
126 
127 }