1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "../Serializer.hpp" 7 #include "SerializerTestUtils.hpp" 8 9 #include <armnn/Descriptors.hpp> 10 #include <armnn/INetwork.hpp> 11 #include <armnn/IRuntime.hpp> 12 #include <armnnDeserializer/IDeserializer.hpp> 13 #include <armnn/utility/IgnoreUnused.hpp> 14 15 #include <doctest/doctest.h> 16 17 TEST_SUITE("SerializerTests") 18 { 19 struct ComparisonModel 20 { ComparisonModelComparisonModel21 ComparisonModel(const std::string& layerName, 22 const armnn::TensorInfo& inputInfo, 23 const armnn::TensorInfo& outputInfo, 24 armnn::ComparisonDescriptor& descriptor) 25 : m_network(armnn::INetwork::Create()) 26 { 27 armnn::IConnectableLayer* const inputLayer0 = m_network->AddInputLayer(0); 28 armnn::IConnectableLayer* const inputLayer1 = m_network->AddInputLayer(1); 29 armnn::IConnectableLayer* const equalLayer = m_network->AddComparisonLayer(descriptor, layerName.c_str()); 30 armnn::IConnectableLayer* const outputLayer = m_network->AddOutputLayer(0); 31 32 inputLayer0->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0)); 33 inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1)); 34 equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 35 36 inputLayer0->GetOutputSlot(0).SetTensorInfo(inputInfo); 37 inputLayer1->GetOutputSlot(0).SetTensorInfo(inputInfo); 38 equalLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 39 } 40 41 armnn::INetworkPtr m_network; 42 }; 43 44 class ComparisonLayerVerifier : public LayerVerifierBase 45 { 46 public: ComparisonLayerVerifier(const std::string & layerName,const std::vector<armnn::TensorInfo> & inputInfos,const std::vector<armnn::TensorInfo> & outputInfos,const armnn::ComparisonDescriptor & descriptor)47 ComparisonLayerVerifier(const std::string& layerName, 48 const std::vector<armnn::TensorInfo>& inputInfos, 49 const std::vector<armnn::TensorInfo>& outputInfos, 50 const armnn::ComparisonDescriptor& descriptor) 51 : LayerVerifierBase(layerName, inputInfos, outputInfos) 52 , m_Descriptor (descriptor) {} 53 ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)54 void ExecuteStrategy(const armnn::IConnectableLayer* layer, 55 const armnn::BaseDescriptor& descriptor, 56 const std::vector<armnn::ConstTensor>& constants, 57 const char* name, 58 const armnn::LayerBindingId id = 0) override 59 { 60 armnn::IgnoreUnused(descriptor, constants, id); 61 switch (layer->GetType()) 62 { 63 case armnn::LayerType::Input: break; 64 case armnn::LayerType::Output: break; 65 case armnn::LayerType::Comparison: 66 { 67 VerifyNameAndConnections(layer, name); 68 const armnn::ComparisonDescriptor& layerDescriptor = 69 static_cast<const armnn::ComparisonDescriptor&>(descriptor); 70 CHECK(layerDescriptor.m_Operation == m_Descriptor.m_Operation); 71 break; 72 } 73 default: 74 { 75 throw armnn::Exception("Unexpected layer type in Comparison test model"); 76 } 77 } 78 } 79 80 private: 81 armnn::ComparisonDescriptor m_Descriptor; 82 }; 83 84 TEST_CASE("SerializeEqual") 85 { 86 const std::string layerName("equal"); 87 88 const armnn::TensorShape shape{2, 1, 2, 4}; 89 const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32); 90 const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); 91 92 armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Equal); 93 94 ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); 95 96 armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); 97 CHECK(deserializedNetwork); 98 99 ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); 100 deserializedNetwork->ExecuteStrategy(verifier); 101 } 102 103 TEST_CASE("SerializeGreater") 104 { 105 const std::string layerName("greater"); 106 107 const armnn::TensorShape shape{2, 1, 2, 4}; 108 const armnn::TensorInfo inputInfo = armnn::TensorInfo(shape, armnn::DataType::Float32); 109 const armnn::TensorInfo outputInfo = armnn::TensorInfo(shape, armnn::DataType::Boolean); 110 111 armnn::ComparisonDescriptor descriptor (armnn::ComparisonOperation::Greater); 112 113 ComparisonModel model(layerName, inputInfo, outputInfo, descriptor); 114 115 armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*model.m_network)); 116 CHECK(deserializedNetwork); 117 118 ComparisonLayerVerifier verifier(layerName, { inputInfo, inputInfo }, { outputInfo }, descriptor); 119 deserializedNetwork->ExecuteStrategy(verifier); 120 } 121 122 } 123