xref: /aosp_15_r20/external/armnn/src/backends/cl/test/ClContextSerializerTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020, 2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include <armnnUtils/Filesystem.hpp>

#include <cl/test/ClContextControlFixture.hpp>

#include <doctest/doctest.h>

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <system_error>
#include <vector>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker 
CreateNetwork()17*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr CreateNetwork()
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network.
20*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net(armnn::INetwork::Create());
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* input = net->AddInputLayer(0, "input");
23*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* softmax = net->AddSoftmaxLayer(armnn::SoftmaxDescriptor(), "softmax");
24*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* output  = net->AddOutputLayer(0, "output");
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(softmax->GetInputSlot(0));
27*89c4ff92SAndroid Build Coastguard Worker     softmax->GetOutputSlot(0).Connect(output->GetInputSlot(0));
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     // Sets the input and output tensors
30*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 10000.0f, 1);
31*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 1.0f/255.0f, 0);
34*89c4ff92SAndroid Build Coastguard Worker     softmax->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     return net;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker 
RunInference(armnn::NetworkId & netId,armnn::IRuntimePtr & runtime,std::vector<uint8_t> & outputData)39*89c4ff92SAndroid Build Coastguard Worker void RunInference(armnn::NetworkId& netId, armnn::IRuntimePtr& runtime, std::vector<uint8_t>& outputData)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output.
42*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> inputData
43*89c4ff92SAndroid Build Coastguard Worker     {
44*89c4ff92SAndroid Build Coastguard Worker         1, 10, 3, 200, 5 // Some inputs - one of which is sufficiently larger than the others to saturate softmax.
45*89c4ff92SAndroid Build Coastguard Worker     };
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = runtime->GetInputTensorInfo(netId, 0);
48*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
49*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors
50*89c4ff92SAndroid Build Coastguard Worker     {
51*89c4ff92SAndroid Build Coastguard Worker         {0, armnn::ConstTensor(inputTensorInfo, inputData.data())}
52*89c4ff92SAndroid Build Coastguard Worker     };
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors
55*89c4ff92SAndroid Build Coastguard Worker     {
56*89c4ff92SAndroid Build Coastguard Worker         {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
57*89c4ff92SAndroid Build Coastguard Worker     };
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     // Run inference.
60*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker 
ReadBinaryFile(const std::string & binaryFileName)63*89c4ff92SAndroid Build Coastguard Worker std::vector<char> ReadBinaryFile(const std::string& binaryFileName)
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     std::ifstream input(binaryFileName, std::ios::binary);
66*89c4ff92SAndroid Build Coastguard Worker     return std::vector<char>(std::istreambuf_iterator<char>(input), {});
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ClContextControlFixture, "ClContextSerializerTest")
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker     // Get tmp directory and create blank file.
74*89c4ff92SAndroid Build Coastguard Worker     fs::path filePath = armnnUtils::Filesystem::NamedTempFile("Armnn-CachedNetworkFileTest-TempFile.bin");
75*89c4ff92SAndroid Build Coastguard Worker     std::string const filePathString{filePath.string()};
76*89c4ff92SAndroid Build Coastguard Worker     std::ofstream file { filePathString };
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     // Create runtime in which test will run
79*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime::CreationOptions options;
80*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker     // Create two networks.
85*89c4ff92SAndroid Build Coastguard Worker     // net1 will serialize and save context to file.
86*89c4ff92SAndroid Build Coastguard Worker     // net2 will deserialize context saved from net1 and load.
87*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net1 = CreateNetwork();
88*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net2 = CreateNetwork();
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     // Add specific optimizerOptions to each network.
91*89c4ff92SAndroid Build Coastguard Worker     armnn::OptimizerOptionsOpaque optimizerOptions1;
92*89c4ff92SAndroid Build Coastguard Worker     armnn::OptimizerOptionsOpaque optimizerOptions2;
93*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions modelOptions1("GpuAcc",
94*89c4ff92SAndroid Build Coastguard Worker                                        {{"SaveCachedNetwork", true}, {"CachedNetworkFilePath", filePathString}});
95*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions modelOptions2("GpuAcc",
96*89c4ff92SAndroid Build Coastguard Worker                                         {{"SaveCachedNetwork", false}, {"CachedNetworkFilePath", filePathString}});
97*89c4ff92SAndroid Build Coastguard Worker     optimizerOptions1.AddModelOption(modelOptions1);
98*89c4ff92SAndroid Build Coastguard Worker     optimizerOptions2.AddModelOption(modelOptions2);
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet1 = armnn::Optimize(
101*89c4ff92SAndroid Build Coastguard Worker             *net1, backends, runtime->GetDeviceSpec(), optimizerOptions1);
102*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet2 = armnn::Optimize(
103*89c4ff92SAndroid Build Coastguard Worker             *net2, backends, runtime->GetDeviceSpec(), optimizerOptions2);
104*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet1);
105*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet2);
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     // Cached file should be empty until net1 is loaded into runtime.
108*89c4ff92SAndroid Build Coastguard Worker     CHECK(fs::is_empty(filePathString));
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     // Load net1 into the runtime.
111*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId netId1;
112*89c4ff92SAndroid Build Coastguard Worker     CHECK(runtime->LoadNetwork(netId1, std::move(optNet1)) == armnn::Status::Success);
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     // File should now exist and not be empty. It has been serialized.
115*89c4ff92SAndroid Build Coastguard Worker     CHECK(fs::exists(filePathString));
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<char> dataSerialized = ReadBinaryFile(filePathString);
117*89c4ff92SAndroid Build Coastguard Worker     CHECK(dataSerialized.size() != 0);
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     // Load net2 into the runtime using file and deserialize.
120*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId netId2;
121*89c4ff92SAndroid Build Coastguard Worker     CHECK(runtime->LoadNetwork(netId2, std::move(optNet2)) == armnn::Status::Success);
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     // Run inference and get output data.
124*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> outputData1(5);
125*89c4ff92SAndroid Build Coastguard Worker     RunInference(netId1, runtime, outputData1);
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> outputData2(5);
128*89c4ff92SAndroid Build Coastguard Worker     RunInference(netId2, runtime, outputData2);
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     // Compare outputs from both networks.
131*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::equal(outputData1.begin(), outputData1.end(), outputData2.begin(), outputData2.end()));
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker     // Remove temp file created.
134*89c4ff92SAndroid Build Coastguard Worker     fs::remove(filePath);
135*89c4ff92SAndroid Build Coastguard Worker }
136