1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "SerializerTestUtils.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker using armnnDeserializer::IDeserializer;
12*89c4ff92SAndroid Build Coastguard Worker
LayerVerifierBase(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos)13*89c4ff92SAndroid Build Coastguard Worker LayerVerifierBase::LayerVerifierBase(const std::string& layerName,
14*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorInfo>& inputInfos,
15*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorInfo>& outputInfos)
16*89c4ff92SAndroid Build Coastguard Worker : m_LayerName(layerName)
17*89c4ff92SAndroid Build Coastguard Worker , m_InputTensorInfos(inputInfos)
18*89c4ff92SAndroid Build Coastguard Worker , m_OutputTensorInfos(outputInfos)
19*89c4ff92SAndroid Build Coastguard Worker {}
20*89c4ff92SAndroid Build Coastguard Worker
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id)21*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::ExecuteStrategy(const armnn::IConnectableLayer* layer,
22*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& descriptor,
23*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::ConstTensor>& constants,
24*89c4ff92SAndroid Build Coastguard Worker const char* name,
25*89c4ff92SAndroid Build Coastguard Worker const armnn::LayerBindingId id)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(descriptor, constants, id);
28*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType())
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Input: break;
31*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Output: break;
32*89c4ff92SAndroid Build Coastguard Worker default:
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker VerifyNameAndConnections(layer, name);
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
39*89c4ff92SAndroid Build Coastguard Worker
VerifyNameAndConnections(const armnn::IConnectableLayer * layer,const char * name)40*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name)
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker CHECK(std::string(name) == m_LayerName.c_str());
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker CHECK(layer->GetNumInputSlots() == m_InputTensorInfos.size());
45*89c4ff92SAndroid Build Coastguard Worker CHECK(layer->GetNumOutputSlots() == m_OutputTensorInfos.size());
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_InputTensorInfos.size(); i++)
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection();
50*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedOutput);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo();
53*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape());
54*89c4ff92SAndroid Build Coastguard Worker CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType()));
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker if (connectedInfo.HasMultipleQuantizationScales())
57*89c4ff92SAndroid Build Coastguard Worker {
58*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedInfo.GetQuantizationScales() == m_InputTensorInfos[i].GetQuantizationScales());
59*89c4ff92SAndroid Build Coastguard Worker }
60*89c4ff92SAndroid Build Coastguard Worker else
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedInfo.GetQuantizationScale() == m_InputTensorInfos[i].GetQuantizationScale());
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker CHECK(connectedInfo.GetQuantizationOffset() == m_InputTensorInfos[i].GetQuantizationOffset());
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < m_OutputTensorInfos.size(); i++)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo = layer->GetOutputSlot(i).GetTensorInfo();
70*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetShape() == m_OutputTensorInfos[i].GetShape());
71*89c4ff92SAndroid Build Coastguard Worker CHECK(GetDataTypeName(outputInfo.GetDataType()) == GetDataTypeName(m_OutputTensorInfos[i].GetDataType()));
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetQuantizationScale() == m_OutputTensorInfos[i].GetQuantizationScale());
74*89c4ff92SAndroid Build Coastguard Worker CHECK(outputInfo.GetQuantizationOffset() == m_OutputTensorInfos[i].GetQuantizationOffset());
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker
VerifyConstTensors(const std::string & tensorName,const armnn::ConstTensor * expectedPtr,const armnn::ConstTensor * actualPtr)78*89c4ff92SAndroid Build Coastguard Worker void LayerVerifierBase::VerifyConstTensors(const std::string& tensorName,
79*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor* expectedPtr,
80*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor* actualPtr)
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker if (expectedPtr == nullptr)
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(actualPtr == nullptr, (tensorName + " should not exist"));
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker else
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(actualPtr != nullptr, (tensorName + " should have been set"));
89*89c4ff92SAndroid Build Coastguard Worker if (actualPtr != nullptr)
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& expectedInfo = expectedPtr->GetInfo();
92*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& actualInfo = actualPtr->GetInfo();
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(expectedInfo.GetShape() == actualInfo.GetShape(),
95*89c4ff92SAndroid Build Coastguard Worker (tensorName + " shapes don't match"));
96*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(
97*89c4ff92SAndroid Build Coastguard Worker GetDataTypeName(expectedInfo.GetDataType()) == GetDataTypeName(actualInfo.GetDataType()),
98*89c4ff92SAndroid Build Coastguard Worker (tensorName + " data types don't match"));
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(expectedPtr->GetNumBytes() == actualPtr->GetNumBytes(),
101*89c4ff92SAndroid Build Coastguard Worker (tensorName + " (GetNumBytes) data sizes do not match"));
102*89c4ff92SAndroid Build Coastguard Worker if (expectedPtr->GetNumBytes() == actualPtr->GetNumBytes())
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker //check the data is identical
105*89c4ff92SAndroid Build Coastguard Worker const char* expectedData = static_cast<const char*>(expectedPtr->GetMemoryArea());
106*89c4ff92SAndroid Build Coastguard Worker const char* actualData = static_cast<const char*>(actualPtr->GetMemoryArea());
107*89c4ff92SAndroid Build Coastguard Worker bool same = true;
108*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < expectedPtr->GetNumBytes(); ++i)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker same = expectedData[i] == actualData[i];
111*89c4ff92SAndroid Build Coastguard Worker if (!same)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker break;
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(same, (tensorName + " data does not match"));
117*89c4ff92SAndroid Build Coastguard Worker }
118*89c4ff92SAndroid Build Coastguard Worker }
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker
CompareConstTensor(const armnn::ConstTensor & tensor1,const armnn::ConstTensor & tensor2)122*89c4ff92SAndroid Build Coastguard Worker void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2)
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker CHECK(tensor1.GetShape() == tensor2.GetShape());
125*89c4ff92SAndroid Build Coastguard Worker CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker switch (tensor1.GetDataType())
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Float32:
130*89c4ff92SAndroid Build Coastguard Worker CompareConstTensorData<const float*>(
131*89c4ff92SAndroid Build Coastguard Worker tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132*89c4ff92SAndroid Build Coastguard Worker break;
133*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QAsymmU8:
134*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Boolean:
135*89c4ff92SAndroid Build Coastguard Worker CompareConstTensorData<const uint8_t*>(
136*89c4ff92SAndroid Build Coastguard Worker tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137*89c4ff92SAndroid Build Coastguard Worker break;
138*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::QSymmS8:
139*89c4ff92SAndroid Build Coastguard Worker CompareConstTensorData<const int8_t*>(
140*89c4ff92SAndroid Build Coastguard Worker tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141*89c4ff92SAndroid Build Coastguard Worker break;
142*89c4ff92SAndroid Build Coastguard Worker case armnn::DataType::Signed32:
143*89c4ff92SAndroid Build Coastguard Worker CompareConstTensorData<const int32_t*>(
144*89c4ff92SAndroid Build Coastguard Worker tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145*89c4ff92SAndroid Build Coastguard Worker break;
146*89c4ff92SAndroid Build Coastguard Worker default:
147*89c4ff92SAndroid Build Coastguard Worker // Note that Float16 is not yet implemented
148*89c4ff92SAndroid Build Coastguard Worker MESSAGE("Unexpected datatype");
149*89c4ff92SAndroid Build Coastguard Worker CHECK(false);
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker
DeserializeNetwork(const std::string & serializerString)153*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString)
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156*89c4ff92SAndroid Build Coastguard Worker return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker
SerializeNetwork(const armnn::INetwork & network)159*89c4ff92SAndroid Build Coastguard Worker std::string SerializeNetwork(const armnn::INetwork& network)
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162*89c4ff92SAndroid Build Coastguard Worker
163*89c4ff92SAndroid Build Coastguard Worker serializer->Serialize(network);
164*89c4ff92SAndroid Build Coastguard Worker
165*89c4ff92SAndroid Build Coastguard Worker std::stringstream stream;
166*89c4ff92SAndroid Build Coastguard Worker serializer->SaveSerializedToStream(stream);
167*89c4ff92SAndroid Build Coastguard Worker
168*89c4ff92SAndroid Build Coastguard Worker std::string serializerString{stream.str()};
169*89c4ff92SAndroid Build Coastguard Worker return serializerString;
170*89c4ff92SAndroid Build Coastguard Worker }
171