1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "SerializerTestUtils.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("SerializerTests") 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker struct ComparisonModel 20*89c4ff92SAndroid Build Coastguard Worker { ComparisonModelComparisonModel21*89c4ff92SAndroid Build Coastguard Worker ComparisonModel(const std::string& layerName, 22*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo, 23*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo, 24*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor& descriptor) 25*89c4ff92SAndroid Build Coastguard Worker : m_network(armnn::INetwork::Create()) 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const inputLayer0 = m_network->AddInputLayer(0); 28*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const inputLayer1 = m_network->AddInputLayer(1); 29*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const equalLayer = m_network->AddComparisonLayer(descriptor, layerName.c_str()); 30*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const outputLayer = m_network->AddOutputLayer(0); 31*89c4ff92SAndroid Build Coastguard Worker 32*89c4ff92SAndroid Build Coastguard Worker inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0)); 33*89c4ff92SAndroid Build Coastguard Worker inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1)); 34*89c4ff92SAndroid Build Coastguard Worker equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo); 37*89c4ff92SAndroid Build Coastguard Worker inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo); 38*89c4ff92SAndroid Build Coastguard Worker equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 39*89c4ff92SAndroid Build Coastguard Worker } 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_network; 42*89c4ff92SAndroid Build Coastguard Worker }; 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker class ComparisonLayerVerifier : public LayerVerifierBase 45*89c4ff92SAndroid Build Coastguard Worker { 46*89c4ff92SAndroid Build Coastguard Worker public: ComparisonLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ComparisonDescriptor & descriptor)47*89c4ff92SAndroid Build Coastguard Worker ComparisonLayerVerifier(const std::string& layerName, 48*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorInfo>& inputInfos, 49*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorInfo>& outputInfos, 50*89c4ff92SAndroid Build Coastguard Worker const armnn::ComparisonDescriptor& descriptor) 51*89c4ff92SAndroid Build Coastguard Worker : LayerVerifierBase(layerName, inputInfos, outputInfos) 52*89c4ff92SAndroid Build Coastguard Worker , m_Descriptor (descriptor) {} 53*89c4ff92SAndroid Build Coastguard Worker ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)54*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(const armnn::IConnectableLayer* layer, 55*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& descriptor, 56*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::ConstTensor>& constants, 57*89c4ff92SAndroid Build Coastguard Worker const char* name, 58*89c4ff92SAndroid Build Coastguard Worker const armnn::LayerBindingId id = 0) override 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(descriptor, constants, id); 61*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType()) 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Input: break; 64*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Output: break; 65*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Comparison: 66*89c4ff92SAndroid Build Coastguard Worker { 67*89c4ff92SAndroid Build Coastguard Worker VerifyNameAndConnections(layer, name); 68*89c4ff92SAndroid Build Coastguard Worker const armnn::ComparisonDescriptor& layerDescriptor = 69*89c4ff92SAndroid Build Coastguard Worker static_cast<const armnn::ComparisonDescriptor&>(descriptor); 70*89c4ff92SAndroid Build Coastguard Worker CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation); 71*89c4ff92SAndroid Build Coastguard Worker break; 72*89c4ff92SAndroid Build Coastguard Worker } 73*89c4ff92SAndroid Build Coastguard Worker default: 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("Unexpected layer type in Comparison test model"); 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker } 78*89c4ff92SAndroid Build Coastguard Worker } 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker private: 81*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor m_Descriptor; 82*89c4ff92SAndroid Build Coastguard Worker }; 83*89c4ff92SAndroid Build Coastguard Worker 84*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeEqual") 85*89c4ff92SAndroid Build Coastguard Worker { 86*89c4ff92SAndroid Build Coastguard Worker const std::string layerName("equal"); 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape shape{2, 1, 2, 4}; 89*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32); 90*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); 91*89c4ff92SAndroid Build Coastguard Worker 92*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Equal); 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); 97*89c4ff92SAndroid Build Coastguard Worker CHECK(deserializedNetwork); 98*89c4ff92SAndroid Build Coastguard Worker 99*89c4ff92SAndroid Build Coastguard Worker ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); 100*89c4ff92SAndroid Build Coastguard Worker deserializedNetwork->ExecuteStrategy(verifier); 101*89c4ff92SAndroid Build Coastguard Worker } 102*89c4ff92SAndroid Build Coastguard Worker 103*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SerializeGreater") 104*89c4ff92SAndroid Build Coastguard Worker { 105*89c4ff92SAndroid Build Coastguard Worker const std::string layerName("greater"); 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape shape{2, 1, 2, 4}; 108*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32); 109*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Greater); 112*89c4ff92SAndroid Build Coastguard Worker 113*89c4ff92SAndroid Build Coastguard Worker ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); 114*89c4ff92SAndroid Build Coastguard Worker 115*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); 116*89c4ff92SAndroid Build Coastguard Worker CHECK(deserializedNetwork); 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); 119*89c4ff92SAndroid Build Coastguard Worker deserializedNetwork->ExecuteStrategy(verifier); 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker } 123