xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/ComparisonSerializationTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "SerializerTestUtils.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("SerializerTests")
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker struct ComparisonModel
20*89c4ff92SAndroid Build Coastguard Worker {
ComparisonModelComparisonModel21*89c4ff92SAndroid Build Coastguard Worker     ComparisonModel(const std::string& layerName,
22*89c4ff92SAndroid Build Coastguard Worker                     const armnn::TensorInfo& inputInfo,
23*89c4ff92SAndroid Build Coastguard Worker                     const armnn::TensorInfo& outputInfo,
24*89c4ff92SAndroid Build Coastguard Worker                     armnn::ComparisonDescriptor& descriptor)
25*89c4ff92SAndroid Build Coastguard Worker             : m_network(armnn::INetwork::Create())
26*89c4ff92SAndroid Build Coastguard Worker     {
27*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const inputLayer0 = m_network->AddInputLayer(0);
28*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const inputLayer1 = m_network->AddInputLayer(1);
29*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const equalLayer = m_network->AddComparisonLayer(descriptor, layerName.c_str());
30*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* const outputLayer = m_network->AddOutputLayer(0);
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker         inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
33*89c4ff92SAndroid Build Coastguard Worker         inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
34*89c4ff92SAndroid Build Coastguard Worker         equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker         inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
37*89c4ff92SAndroid Build Coastguard Worker         inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
38*89c4ff92SAndroid Build Coastguard Worker         equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
39*89c4ff92SAndroid Build Coastguard Worker     }
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr m_network;
42*89c4ff92SAndroid Build Coastguard Worker };
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker class ComparisonLayerVerifier : public LayerVerifierBase
45*89c4ff92SAndroid Build Coastguard Worker {
46*89c4ff92SAndroid Build Coastguard Worker public:
ComparisonLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ComparisonDescriptor & descriptor)47*89c4ff92SAndroid Build Coastguard Worker     ComparisonLayerVerifier(const std::string& layerName,
48*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& inputInfos,
49*89c4ff92SAndroid Build Coastguard Worker                             const std::vector<armnn::TensorInfo>& outputInfos,
50*89c4ff92SAndroid Build Coastguard Worker                             const armnn::ComparisonDescriptor& descriptor)
51*89c4ff92SAndroid Build Coastguard Worker             : LayerVerifierBase(layerName, inputInfos, outputInfos)
52*89c4ff92SAndroid Build Coastguard Worker             , m_Descriptor (descriptor) {}
53*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)54*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
55*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
56*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
57*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
58*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker         armnn::IgnoreUnused(descriptor, constants, id);
61*89c4ff92SAndroid Build Coastguard Worker         switch (layer->GetType())
62*89c4ff92SAndroid Build Coastguard Worker         {
63*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Input: break;
64*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Output: break;
65*89c4ff92SAndroid Build Coastguard Worker             case armnn::LayerType::Comparison:
66*89c4ff92SAndroid Build Coastguard Worker             {
67*89c4ff92SAndroid Build Coastguard Worker                 VerifyNameAndConnections(layer, name);
68*89c4ff92SAndroid Build Coastguard Worker                 const armnn::ComparisonDescriptor& layerDescriptor =
69*89c4ff92SAndroid Build Coastguard Worker                         static_cast<const armnn::ComparisonDescriptor&>(descriptor);
70*89c4ff92SAndroid Build Coastguard Worker                 CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation);
71*89c4ff92SAndroid Build Coastguard Worker                 break;
72*89c4ff92SAndroid Build Coastguard Worker             }
73*89c4ff92SAndroid Build Coastguard Worker             default:
74*89c4ff92SAndroid Build Coastguard Worker             {
75*89c4ff92SAndroid Build Coastguard Worker                 throw armnn::Exception("Unexpected layer type in Comparison test model");
76*89c4ff92SAndroid Build Coastguard Worker             }
77*89c4ff92SAndroid Build Coastguard Worker         }
78*89c4ff92SAndroid Build Coastguard Worker     }
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker private:
81*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor m_Descriptor;
82*89c4ff92SAndroid Build Coastguard Worker };
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeEqual")
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("equal");
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 4};
89*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
90*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Equal);
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
97*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
98*89c4ff92SAndroid Build Coastguard Worker 
99*89c4ff92SAndroid Build Coastguard Worker     ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
100*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeGreater")
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker     const std::string layerName("greater");
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape shape{2, 1, 2, 4};
108*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
109*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker     armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Greater);
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
116*89c4ff92SAndroid Build Coastguard Worker     CHECK(deserializedNetwork);
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
119*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(verifier);
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker }
123