1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker #include <sstream> 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("SerializerTests") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker class VerifyActivationName : public armnn::IStrategy 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker public: ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)23*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(const armnn::IConnectableLayer* layer, 24*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& descriptor, 25*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::ConstTensor>& constants, 26*89c4ff92SAndroid Build Coastguard Worker const char* name, 27*89c4ff92SAndroid Build Coastguard Worker const armnn::LayerBindingId id = 0) override 28*89c4ff92SAndroid Build Coastguard Worker { 29*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(layer, descriptor, constants, id); 30*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType() == armnn::LayerType::Activation) 31*89c4ff92SAndroid Build Coastguard Worker { 32*89c4ff92SAndroid Build Coastguard Worker CHECK(std::string(name) == "activation"); 33*89c4ff92SAndroid Build Coastguard Worker } 34*89c4ff92SAndroid Build Coastguard Worker } 35*89c4ff92SAndroid Build Coastguard Worker }; 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ActivationSerialization") 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create(); 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0); 42*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0); 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker // Construct network 45*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = armnn::INetwork::Create(); 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker armnn::ActivationDescriptor descriptor; 48*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Function = armnn::ActivationFunction::ReLu; 49*89c4ff92SAndroid Build Coastguard Worker descriptor.m_A = 0; 50*89c4ff92SAndroid Build Coastguard Worker descriptor.m_B = 0; 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const inputLayer = network->AddInputLayer(0, "input"); 53*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation"); 54*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0, "output"); 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0)); 57*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo); 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); 60*89c4ff92SAndroid Build Coastguard Worker activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo); 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create(); 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker serializer->Serialize(*network); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker std::stringstream stream; 67*89c4ff92SAndroid Build Coastguard Worker serializer->SaveSerializedToStream(stream); 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker std::string const serializerString{stream.str()}; 70*89c4ff92SAndroid Build Coastguard Worker std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()}; 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector); 73*89c4ff92SAndroid Build Coastguard Worker 74*89c4ff92SAndroid Build Coastguard Worker VerifyActivationName visitor; 75*89c4ff92SAndroid Build Coastguard Worker deserializedNetwork->ExecuteStrategy(visitor); 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime::CreationOptions options; // default options 78*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntimePtr run = armnn::IRuntime::Create(options); 79*89c4ff92SAndroid Build Coastguard Worker auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec()); 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId networkIdentifier; 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker // Load graph into runtime 84*89c4ff92SAndroid Build Coastguard Worker run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized)); 85*89c4ff92SAndroid Build Coastguard Worker 86*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f}; 87*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0); 88*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true); 89*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors 90*89c4ff92SAndroid Build Coastguard Worker { 91*89c4ff92SAndroid Build Coastguard Worker {0, armnn::ConstTensor(inputTensorInfo, inputData.data())} 92*89c4ff92SAndroid Build Coastguard Worker }; 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f}; 95*89c4ff92SAndroid Build Coastguard Worker 96*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(4); 97*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors 98*89c4ff92SAndroid Build Coastguard Worker { 99*89c4ff92SAndroid Build Coastguard Worker {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())} 100*89c4ff92SAndroid Build Coastguard Worker }; 101*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors); 102*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutputData.begin(), expectedOutputData.end())); 103*89c4ff92SAndroid Build Coastguard Worker } 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker } 106