xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/SerializerTestUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#pragma once

#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <algorithm>
#include <cstdlib>
#include <limits>
#include <random>
#include <string>
#include <type_traits>
#include <vector>

#include <doctest/doctest.h>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr DeserializeNetwork(const std::string& serializerString);
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker std::string SerializeNetwork(const armnn::INetwork& network);
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker void CompareConstTensor(const armnn::ConstTensor& tensor1, const armnn::ConstTensor& tensor2);
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker class LayerVerifierBase : public armnn::IStrategy
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker public:
27*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBase(const std::string& layerName,
28*89c4ff92SAndroid Build Coastguard Worker                       const std::vector<armnn::TensorInfo>& inputInfos,
29*89c4ff92SAndroid Build Coastguard Worker                       const std::vector<armnn::TensorInfo>& outputInfos);
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
32*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
33*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
34*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
35*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override;
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker protected:
38*89c4ff92SAndroid Build Coastguard Worker     void VerifyNameAndConnections(const armnn::IConnectableLayer* layer, const char* name);
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     void VerifyConstTensors(const std::string& tensorName,
41*89c4ff92SAndroid Build Coastguard Worker                             const armnn::ConstTensor* expectedPtr,
42*89c4ff92SAndroid Build Coastguard Worker                             const armnn::ConstTensor* actualPtr);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker private:
45*89c4ff92SAndroid Build Coastguard Worker     std::string m_LayerName;
46*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> m_InputTensorInfos;
47*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::TensorInfo> m_OutputTensorInfos;
48*89c4ff92SAndroid Build Coastguard Worker };
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker template<typename Descriptor>
51*89c4ff92SAndroid Build Coastguard Worker class LayerVerifierBaseWithDescriptor : public LayerVerifierBase
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker public:
LayerVerifierBaseWithDescriptor(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const Descriptor & descriptor)54*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptor(const std::string& layerName,
55*89c4ff92SAndroid Build Coastguard Worker                                     const std::vector<armnn::TensorInfo>& inputInfos,
56*89c4ff92SAndroid Build Coastguard Worker                                     const std::vector<armnn::TensorInfo>& outputInfos,
57*89c4ff92SAndroid Build Coastguard Worker                                     const Descriptor& descriptor)
58*89c4ff92SAndroid Build Coastguard Worker         : LayerVerifierBase(layerName, inputInfos, outputInfos)
59*89c4ff92SAndroid Build Coastguard Worker         , m_Descriptor(descriptor) {}
60*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)61*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
62*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
63*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
64*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
65*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
66*89c4ff92SAndroid Build Coastguard Worker     {
67*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(constants, id);
68*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
69*89c4ff92SAndroid Build Coastguard Worker         {
70*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
71*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
72*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Constant: break;
73*89c4ff92SAndroid Build Coastguard Worker             default:
74*89c4ff92SAndroid Build Coastguard Worker             {
75*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
76*89c4ff92SAndroid Build Coastguard Worker                 const Descriptor& internalDescriptor = static_cast<const Descriptor&>(descriptor);
77*89c4ff92SAndroid Build Coastguard Worker                 VerifyDescriptor(internalDescriptor);
78*89c4ff92SAndroid Build Coastguard Worker                 break;
79*89c4ff92SAndroid Build Coastguard Worker             }
80*89c4ff92SAndroid Build Coastguard Worker         }
81*89c4ff92SAndroid Build Coastguard Worker     }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker protected:
VerifyDescriptor(const Descriptor & descriptor)84*89c4ff92SAndroid Build Coastguard Worker     void VerifyDescriptor(const Descriptor& descriptor)
85*89c4ff92SAndroid Build Coastguard Worker     {
86*89c4ff92SAndroid Build Coastguard Worker         CHECK(descriptor == m_Descriptor);
87*89c4ff92SAndroid Build Coastguard Worker     }
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     Descriptor m_Descriptor;
90*89c4ff92SAndroid Build Coastguard Worker };
91*89c4ff92SAndroid Build Coastguard Worker 
/// Element-wise comparison of two raw tensor buffers using doctest CHECKs.
/// @tparam T            Pointer-to-element type both buffers are viewed as
///                      (presumably something like `const float*`); it must be
///                      static_cast-able from `const void*`.
/// @param data1, data2  Buffers to compare; a null buffer is reported as a failure.
/// @param numElements   Number of elements to compare.
template<typename T>
void CompareConstTensorData(const void* data1, const void* data2, unsigned int numElements)
{
    T typedData1 = static_cast<T>(data1);
    T typedData2 = static_cast<T>(data2);
    CHECK(typedData1);
    CHECK(typedData2);

    // CHECK only records a failure and lets the test continue, so stop here
    // instead of dereferencing a null buffer in the loop below.
    if (!typedData1 || !typedData2)
    {
        return;
    }

    for (unsigned int i = 0; i < numElements; ++i)
    {
        CHECK(typedData1[i] == typedData2[i]);
    }
}
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker template <typename Descriptor>
108*89c4ff92SAndroid Build Coastguard Worker class LayerVerifierBaseWithDescriptorAndConstants : public LayerVerifierBaseWithDescriptor<Descriptor>
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker public:
LayerVerifierBaseWithDescriptorAndConstants(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const Descriptor & descriptor,const std::vector<armnn::ConstTensor> & constants)111*89c4ff92SAndroid Build Coastguard Worker     LayerVerifierBaseWithDescriptorAndConstants(const std::string& layerName,
112*89c4ff92SAndroid Build Coastguard Worker                                                 const std::vector<armnn::TensorInfo>& inputInfos,
113*89c4ff92SAndroid Build Coastguard Worker                                                 const std::vector<armnn::TensorInfo>& outputInfos,
114*89c4ff92SAndroid Build Coastguard Worker                                                 const Descriptor& descriptor,
115*89c4ff92SAndroid Build Coastguard Worker                                                 const std::vector<armnn::ConstTensor>& constants)
116*89c4ff92SAndroid Build Coastguard Worker             : LayerVerifierBaseWithDescriptor<Descriptor>(layerName, inputInfos, outputInfos, descriptor)
117*89c4ff92SAndroid Build Coastguard Worker             , m_Constants(constants) {}
118*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)119*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
120*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
121*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
122*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
123*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
124*89c4ff92SAndroid Build Coastguard Worker     {
125*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(id);
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
128*89c4ff92SAndroid Build Coastguard Worker         {
129*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
130*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
131*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Constant: break;
132*89c4ff92SAndroid Build Coastguard Worker             default:
133*89c4ff92SAndroid Build Coastguard Worker             {
134*89c4ff92SAndroid Build Coastguard Worker                 this->VerifyNameAndConnections(layer, name);
135*89c4ff92SAndroid Build Coastguard Worker                 const Descriptor& internalDescriptor = static_cast<const Descriptor&>(descriptor);
136*89c4ff92SAndroid Build Coastguard Worker                 this->VerifyDescriptor(internalDescriptor);
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker                 for(std::size_t i = 0; i < constants.size(); i++)
139*89c4ff92SAndroid Build Coastguard Worker                 {
140*89c4ff92SAndroid Build Coastguard Worker                     CompareConstTensor(constants[i], m_Constants[i]);
141*89c4ff92SAndroid Build Coastguard Worker                 }
142*89c4ff92SAndroid Build Coastguard Worker             }
143*89c4ff92SAndroid Build Coastguard Worker         }
144*89c4ff92SAndroid Build Coastguard Worker     }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker private:
147*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::ConstTensor> m_Constants;
148*89c4ff92SAndroid Build Coastguard Worker };
149*89c4ff92SAndroid Build Coastguard Worker 
/// Returns `size` pseudo-random values of type DataType.
///
/// Integral types are drawn uniformly from
/// [numeric_limits<DataType>::min(), numeric_limits<DataType>::max()].
/// Floating-point types are drawn from [numeric_limits::min(), numeric_limits::max()).
/// For floating point, min() is the smallest positive normal value, so every
/// generated float is positive.
///
/// The engine and distribution are function-local statics, one pair per DataType.
/// The engine is default-seeded, so results are deterministic for a given call order.
/// They are not thread-safe.
template<typename DataType>
static std::vector<DataType> GenerateRandomData(size_t size)
{
    constexpr bool isIntegerType = std::is_integral<DataType>::value;

    // std::uniform_int_distribution is only defined for short, int, long, long long
    // and their unsigned variants. Char-sized types (int8_t, uint8_t, ...) and bool
    // are therefore sampled through int / unsigned int over the same range and narrowed.
    constexpr bool needsWiderSample =
        isIntegerType && (sizeof(DataType) < sizeof(short) || std::is_same<DataType, bool>::value);
    using SampleType =
        typename std::conditional<needsWiderSample,
                                  typename std::conditional<std::is_signed<DataType>::value,
                                                            int,
                                                            unsigned int>::type,
                                  DataType>::type;

    using Distribution =
        typename std::conditional<isIntegerType,
                                  std::uniform_int_distribution<SampleType>,
                                  std::uniform_real_distribution<SampleType>>::type;

    static constexpr DataType lowerLimit = std::numeric_limits<DataType>::min();
    static constexpr DataType upperLimit = std::numeric_limits<DataType>::max();

    static Distribution distribution(lowerLimit, upperLimit);
    static std::default_random_engine generator;

    std::vector<DataType> randomData(size);
    // Qualified call: do not rely on ADL through the iterator type.
    std::generate(randomData.begin(), randomData.end(),
                  []() { return static_cast<DataType>(distribution(generator)); });

    return randomData;
}
169*89c4ff92SAndroid Build Coastguard Worker }