xref: /aosp_15_r20/external/armnn/src/armnnSerializer/test/ActivationSerializationTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../Serializer.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <sstream>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("SerializerTests")
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker class VerifyActivationName : public armnn::IStrategy
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker public:
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)23*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(const armnn::IConnectableLayer* layer,
24*89c4ff92SAndroid Build Coastguard Worker                          const armnn::BaseDescriptor& descriptor,
25*89c4ff92SAndroid Build Coastguard Worker                          const std::vector<armnn::ConstTensor>& constants,
26*89c4ff92SAndroid Build Coastguard Worker                          const char* name,
27*89c4ff92SAndroid Build Coastguard Worker                          const armnn::LayerBindingId id = 0) override
28*89c4ff92SAndroid Build Coastguard Worker     {
29*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(layer, descriptor, constants, id);
30*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType() == armnn::LayerType::Activation)
31*89c4ff92SAndroid Build Coastguard Worker         {
32*89c4ff92SAndroid Build Coastguard Worker             CHECK(std::string(name) == "activation");
33*89c4ff92SAndroid Build Coastguard Worker         }
34*89c4ff92SAndroid Build Coastguard Worker     }
35*89c4ff92SAndroid Build Coastguard Worker };
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ActivationSerialization")
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     armnnDeserializer::IDeserializerPtr parser = armnnDeserializer::IDeserializer::Create();
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 1.0f, 0);
42*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo(armnn::TensorShape({1, 2, 2, 1}), armnn::DataType::Float32, 4.0f, 0);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     // Construct network
45*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetwork::Create();
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     armnn::ActivationDescriptor descriptor;
48*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Function = armnn::ActivationFunction::ReLu;
49*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_A = 0;
50*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_B = 0;
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const inputLayer      = network->AddInputLayer(0, "input");
53*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const activationLayer = network->AddActivationLayer(descriptor, "activation");
54*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const outputLayer     = network->AddOutputLayer(0, "output");
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
57*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
60*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker     armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     serializer->Serialize(*network);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     std::stringstream stream;
67*89c4ff92SAndroid Build Coastguard Worker     serializer->SaveSerializedToStream(stream);
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     std::string const serializerString{stream.str()};
70*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr deserializedNetwork = parser->CreateNetworkFromBinary(serializerVector);
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     VerifyActivationName visitor;
75*89c4ff92SAndroid Build Coastguard Worker     deserializedNetwork->ExecuteStrategy(visitor);
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime::CreationOptions options; // default options
78*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntimePtr run = armnn::IRuntime::Create(options);
79*89c4ff92SAndroid Build Coastguard Worker     auto deserializedOptimized = Optimize(*deserializedNetwork, { armnn::Compute::CpuRef }, run->GetDeviceSpec());
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId networkIdentifier;
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker     // Load graph into runtime
84*89c4ff92SAndroid Build Coastguard Worker     run->LoadNetwork(networkIdentifier, std::move(deserializedOptimized));
85*89c4ff92SAndroid Build Coastguard Worker 
86*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData {0.0f, -5.3f, 42.0f, -42.0f};
87*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
88*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
89*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
92*89c4ff92SAndroid Build Coastguard Worker     };
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputData {0.0f, 0.0f, 42.0f, 0.0f};
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(4);
97*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors
98*89c4ff92SAndroid Build Coastguard Worker     {
99*89c4ff92SAndroid Build Coastguard Worker         {0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData.data())}
100*89c4ff92SAndroid Build Coastguard Worker     };
101*89c4ff92SAndroid Build Coastguard Worker     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
102*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData.begin(), outputData.end(), expectedOutputData.begin(), expectedOutputData.end()));
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker }
106