xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/SerializerTestUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "SerializerTestUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker using armnnDeserializer::IDeserializer;
12*89c4ff92SAndroid Build Coastguard Worker 
LayerVerifierBase(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)13*89c4ff92SAndroid Build Coastguard Worker LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14*89c4ff92SAndroid Build Coastguard Worker                                      const std::vector<armnn::TensorInfo>& inputInfos,
15*89c4ff92SAndroid Build Coastguard Worker                                      const std::vector<armnn::TensorInfo>& outputInfos)
16*89c4ff92SAndroid Build Coastguard Worker                                      : m_LayerName(layerName)
17*89c4ff92SAndroid Build Coastguard Worker                                      , m_InputTensorInfos(inputInfos)
18*89c4ff92SAndroid Build Coastguard Worker                                      , m_OutputTensorInfos(outputInfos)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id)21*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22*89c4ff92SAndroid Build Coastguard Worker                      const armnn::BaseDescriptor& descriptor,
23*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<armnn::ConstTensor>& constants,
24*89c4ff92SAndroid Build Coastguard Worker                      const char* name,
25*89c4ff92SAndroid Build Coastguard Worker                      const armnn::LayerBindingId id)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker     armnn::IgnoreUnused(descriptor, constants, id);
28*89c4ff92SAndroid Build Coastguard Worker     switch (layer->GetType())
29*89c4ff92SAndroid Build Coastguard Worker     {
30*89c4ff92SAndroid Build Coastguard Worker         case armnn::LayerType::Input: break;
31*89c4ff92SAndroid Build Coastguard Worker         case armnn::LayerType::Output: break;
32*89c4ff92SAndroid Build Coastguard Worker         default:
33*89c4ff92SAndroid Build Coastguard Worker         {
34*89c4ff92SAndroid Build Coastguard Worker             VerifyNameAndConnections(layer, name);
35*89c4ff92SAndroid Build Coastguard Worker         }
36*89c4ff92SAndroid Build Coastguard Worker     }
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker 
VerifyNameAndConnections(const armnn::IConnectableLayer * layer,const char * name)40*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::string(name) == m_LayerName.c_str());
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48*89c4ff92SAndroid Build Coastguard Worker     {
49*89c4ff92SAndroid Build Coastguard Worker         const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
50*89c4ff92SAndroid Build Coastguard Worker         CHECK(connectedOutput);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
53*89c4ff92SAndroid Build Coastguard Worker         CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
54*89c4ff92SAndroid Build Coastguard Worker         CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker         if (connectedInfo.HasMultipleQuantizationScales())
57*89c4ff92SAndroid Build Coastguard Worker         {
58*89c4ff92SAndroid Build Coastguard Worker             CHECK(connectedInfo.GetQuantizationScales() == m_InputTensorInfos[i].GetQuantizationScales());
59*89c4ff92SAndroid Build Coastguard Worker         }
60*89c4ff92SAndroid Build Coastguard Worker         else
61*89c4ff92SAndroid Build Coastguard Worker         {
62*89c4ff92SAndroid Build Coastguard Worker             CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
63*89c4ff92SAndroid Build Coastguard Worker         }
64*89c4ff92SAndroid Build Coastguard Worker         CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
65*89c4ff92SAndroid Build Coastguard Worker     }
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
70*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
71*89c4ff92SAndroid Build Coastguard Worker         CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
74*89c4ff92SAndroid Build Coastguard Worker         CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
VerifyConstTensors(const std::string & tensorName,const armnn::ConstTensor * expectedPtr,const armnn::ConstTensor * actualPtr)78*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
79*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::ConstTensor* expectedPtr,
80*89c4ff92SAndroid Build Coastguard Worker                                            const armnn::ConstTensor* actualPtr)
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker     if (expectedPtr == nullptr)
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(actualPtr == nullptr, (tensorName + " should not exist"));
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker     else
87*89c4ff92SAndroid Build Coastguard Worker     {
88*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(actualPtr != nullptr, (tensorName + " should have been set"));
89*89c4ff92SAndroid Build Coastguard Worker         if (actualPtr != nullptr)
90*89c4ff92SAndroid Build Coastguard Worker         {
91*89c4ff92SAndroid Build Coastguard Worker             const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
92*89c4ff92SAndroid Build Coastguard Worker             const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker             CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
95*89c4ff92SAndroid Build Coastguard Worker                           (tensorName + " shapes don't match"));
96*89c4ff92SAndroid Build Coastguard Worker             CHECK_MESSAGE(
97*89c4ff92SAndroid Build Coastguard Worker                     GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
98*89c4ff92SAndroid Build Coastguard Worker                     (tensorName + " data types don't match"));
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker             CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
101*89c4ff92SAndroid Build Coastguard Worker                           (tensorName + " (GetNumBytes) data sizes do not match"));
102*89c4ff92SAndroid Build Coastguard Worker             if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
103*89c4ff92SAndroid Build Coastguard Worker             {
104*89c4ff92SAndroid Build Coastguard Worker                 //check the data is identical
105*89c4ff92SAndroid Build Coastguard Worker                 const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
106*89c4ff92SAndroid Build Coastguard Worker                 const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
107*89c4ff92SAndroid Build Coastguard Worker                 bool same = true;
108*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
109*89c4ff92SAndroid Build Coastguard Worker                 {
110*89c4ff92SAndroid Build Coastguard Worker                     same = expectedData[i] == actualData[i];
111*89c4ff92SAndroid Build Coastguard Worker                     if (!same)
112*89c4ff92SAndroid Build Coastguard Worker                     {
113*89c4ff92SAndroid Build Coastguard Worker                         break;
114*89c4ff92SAndroid Build Coastguard Worker                     }
115*89c4ff92SAndroid Build Coastguard Worker                 }
116*89c4ff92SAndroid Build Coastguard Worker                 CHECK_MESSAGE(same, (tensorName + " data does not match"));
117*89c4ff92SAndroid Build Coastguard Worker             }
118*89c4ff92SAndroid Build Coastguard Worker         }
119*89c4ff92SAndroid Build Coastguard Worker     }
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker 
CompareConstTensor(const armnn::ConstTensor & tensor1,const armnn::ConstTensor & tensor2)122*89c4ff92SAndroid Build Coastguard Worker void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker     CHECK(tensor1.GetShape() == tensor2.GetShape());
125*89c4ff92SAndroid Build Coastguard Worker     CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     switch (tensor1.GetDataType())
128*89c4ff92SAndroid Build Coastguard Worker     {
129*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float32:
130*89c4ff92SAndroid Build Coastguard Worker             CompareConstTensorData<const float*>(
131*89c4ff92SAndroid Build Coastguard Worker                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132*89c4ff92SAndroid Build Coastguard Worker             break;
133*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmU8:
134*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Boolean:
135*89c4ff92SAndroid Build Coastguard Worker             CompareConstTensorData<const uint8_t*>(
136*89c4ff92SAndroid Build Coastguard Worker                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137*89c4ff92SAndroid Build Coastguard Worker             break;
138*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS8:
139*89c4ff92SAndroid Build Coastguard Worker             CompareConstTensorData<const int8_t*>(
140*89c4ff92SAndroid Build Coastguard Worker                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141*89c4ff92SAndroid Build Coastguard Worker             break;
142*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Signed32:
143*89c4ff92SAndroid Build Coastguard Worker             CompareConstTensorData<const int32_t*>(
144*89c4ff92SAndroid Build Coastguard Worker                 tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145*89c4ff92SAndroid Build Coastguard Worker             break;
146*89c4ff92SAndroid Build Coastguard Worker         default:
147*89c4ff92SAndroid Build Coastguard Worker             // Note that Float16 is not yet implemented
148*89c4ff92SAndroid Build Coastguard Worker             MESSAGE("Unexpected datatype");
149*89c4ff92SAndroid Build Coastguard Worker             CHECK(false);
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
DeserializeNetwork(const std::string & serializerString)153*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156*89c4ff92SAndroid Build Coastguard Worker     return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker 
SerializeNetwork(const armnn::INetwork & network)159*89c4ff92SAndroid Build Coastguard Worker std::string SerializeNetwork(const armnn::INetwork& network)
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker     armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker     serializer->Serialize(network);
164*89c4ff92SAndroid Build Coastguard Worker 
165*89c4ff92SAndroid Build Coastguard Worker     std::stringstream stream;
166*89c4ff92SAndroid Build Coastguard Worker     serializer->SaveSerializedToStream(stream);
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker     std::string serializerString{stream.str()};
169*89c4ff92SAndroid Build Coastguard Worker     return serializerString;
170*89c4ff92SAndroid Build Coastguard Worker }
171