xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/ComparisonSerializationTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "../Serializer.hpp"
7 #include "SerializerTestUtils.hpp"
8 
9 #include <armnn/Descriptors.hpp>
10 #include <armnn/INetwork.hpp>
11 #include <armnn/IRuntime.hpp>
12 #include <armnnDeserializer/IDeserializer.hpp>
13 #include <armnn/utility/IgnoreUnused.hpp>
14 
15 #include <doctest/doctest.h>
16 
17 TEST_SUITE("SerializerTests")
18 {
19 struct ComparisonModel
20 {
ComparisonModelComparisonModel21     ComparisonModel(const std::string& layerName,
22                     const armnn::TensorInfo& inputInfo,
23                     const armnn::TensorInfo& outputInfo,
24                     armnn::ComparisonDescriptor& descriptor)
25             : m_network(armnn::INetwork::Create())
26     {
27         armnn::IConnectableLayer* const inputLayer0 = m_network->AddInputLayer(0);
28         armnn::IConnectableLayer* const inputLayer1 = m_network->AddInputLayer(1);
29         armnn::IConnectableLayer* const equalLayer = m_network->AddComparisonLayer(descriptor, layerName.c_str());
30         armnn::IConnectableLayer* const outputLayer = m_network->AddOutputLayer(0);
31 
32         inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
33         inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
34         equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
35 
36         inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo);
37         inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo);
38         equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
39     }
40 
41     armnn::INetworkPtr m_network;
42 };
43 
44 class ComparisonLayerVerifier : public LayerVerifierBase
45 {
46 public:
ComparisonLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ComparisonDescriptor & descriptor)47     ComparisonLayerVerifier(const std::string& layerName,
48                             const std::vector<armnn::TensorInfo>& inputInfos,
49                             const std::vector<armnn::TensorInfo>& outputInfos,
50                             const armnn::ComparisonDescriptor& descriptor)
51             : LayerVerifierBase(layerName, inputInfos, outputInfos)
52             , m_Descriptor (descriptor) {}
53 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)54     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
55                          const armnn::BaseDescriptor& descriptor,
56                          const std::vector<armnn::ConstTensor>& constants,
57                          const char* name,
58                          const armnn::LayerBindingId id = 0) override
59     {
60         armnn::IgnoreUnused(descriptor, constants, id);
61         switch (layer->GetType())
62         {
63             case armnn::LayerType::Input: break;
64             case armnn::LayerType::Output: break;
65             case armnn::LayerType::Comparison:
66             {
67                 VerifyNameAndConnections(layer, name);
68                 const armnn::ComparisonDescriptor& layerDescriptor =
69                         static_cast<const armnn::ComparisonDescriptor&>(descriptor);
70                 CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation);
71                 break;
72             }
73             default:
74             {
75                 throw armnn::Exception("Unexpected layer type in Comparison test model");
76             }
77         }
78     }
79 
80 private:
81     armnn::ComparisonDescriptor m_Descriptor;
82 };
83 
84 TEST_CASE("SerializeEqual")
85 {
86     const std::string layerName("equal");
87 
88     const armnn::TensorShape shape{2, 1, 2, 4};
89     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
90     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
91 
92     armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Equal);
93 
94     ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
95 
96     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
97     CHECK(deserializedNetwork);
98 
99     ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
100     deserializedNetwork->ExecuteStrategy(verifier);
101 }
102 
103 TEST_CASE("SerializeGreater")
104 {
105     const std::string layerName("greater");
106 
107     const armnn::TensorShape shape{2, 1, 2, 4};
108     const armnn::TensorInfo inputInfo  = armnn::TensorInfo(shape, armnn::DataType::Float32);
109     const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean);
110 
111     armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Greater);
112 
113     ComparisonModel model(layerName, inputInfo, outputInfo, descriptor);
114 
115     armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network));
116     CHECK(deserializedNetwork);
117 
118     ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor);
119     deserializedNetwork->ExecuteStrategy(verifier);
120 }
121 
122 }
123