1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Utils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
11*89c4ff92SAndroid Build Coastguard Worker #include <thread>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker /// A simple example of using the ArmNN SDK API to run a network multiple times with different inputs in an asynchronous
14*89c4ff92SAndroid Build Coastguard Worker /// manner.
15*89c4ff92SAndroid Build Coastguard Worker ///
16*89c4ff92SAndroid Build Coastguard Worker /// Background info: The usual runtime->EnqueueWorkload, which is used to trigger the execution of a network, is not
17*89c4ff92SAndroid Build Coastguard Worker /// thread safe. Each workload has memory assigned to it which would be overwritten by each thread.
18*89c4ff92SAndroid Build Coastguard Worker /// Before we added support for this you had to load a network multiple times to execute it at the
19*89c4ff92SAndroid Build Coastguard Worker /// same time. Every time a network is loaded, it takes up memory on your device. Making the
20*89c4ff92SAndroid Build Coastguard Worker /// execution thread safe helps to reduce the memory footprint for concurrent executions significantly.
21*89c4ff92SAndroid Build Coastguard Worker /// This example shows you how to execute a model concurrently (multiple threads) while still only
22*89c4ff92SAndroid Build Coastguard Worker /// loading it once.
23*89c4ff92SAndroid Build Coastguard Worker ///
/// As in most of our simple samples, this example asks the user for a single input number for each
/// execution of the network.
/// The network consists of a single fully connected layer with a single neuron. The neuron's weight is set to 1.0f
/// to produce an output number that is the same as the input.
main()28*89c4ff92SAndroid Build Coastguard Worker int main()
29*89c4ff92SAndroid Build Coastguard Worker {
30*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker // The first part of this code is very similar to the SimpleSample.cpp you should check it out for comparison
33*89c4ff92SAndroid Build Coastguard Worker // The interesting part starts when the graph is loaded into the runtime
34*89c4ff92SAndroid Build Coastguard Worker
35*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputs;
36*89c4ff92SAndroid Build Coastguard Worker float number1;
37*89c4ff92SAndroid Build Coastguard Worker std::cout << "Please enter a number for the first iteration: " << std::endl;
38*89c4ff92SAndroid Build Coastguard Worker std::cin >> number1;
39*89c4ff92SAndroid Build Coastguard Worker float number2;
40*89c4ff92SAndroid Build Coastguard Worker std::cout << "Please enter a number for the second iteration: " << std::endl;
41*89c4ff92SAndroid Build Coastguard Worker std::cin >> number2;
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker // Turn on logging to standard output
44*89c4ff92SAndroid Build Coastguard Worker // This is useful in this sample so that users can learn more about what is going on
45*89c4ff92SAndroid Build Coastguard Worker ConfigureLogging(true, false, LogSeverity::Warning);
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
48*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier;
49*89c4ff92SAndroid Build Coastguard Worker INetworkPtr myNetwork = INetwork::Create();
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker float weightsData[] = {1.0f}; // Identity
52*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32, 0.0f, 0, true);
53*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetConstant();
54*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsData);
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker // Constant layer that now holds weights data for FullyConnected
57*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "const weights");
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor fullyConnectedDesc;
60*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
61*89c4ff92SAndroid Build Coastguard Worker "fully connected");
62*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* InputLayer = myNetwork->AddInputLayer(0);
63*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* OutputLayer = myNetwork->AddOutputLayer(0);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
66*89c4ff92SAndroid Build Coastguard Worker constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
67*89c4ff92SAndroid Build Coastguard Worker fullyConnectedLayer->GetOutputSlot(0).Connect(OutputLayer->GetInputSlot(0));
68*89c4ff92SAndroid Build Coastguard Worker
69*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime
70*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; // default options
71*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(options);
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker //Set the tensors in the network.
74*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
75*89c4ff92SAndroid Build Coastguard Worker InputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
78*89c4ff92SAndroid Build Coastguard Worker fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
79*89c4ff92SAndroid Build Coastguard Worker constantWeightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker // Optimise ArmNN network
82*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*myNetwork, {Compute::CpuRef}, run->GetDeviceSpec());
83*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker // This shouldn't happen for this simple sample, with reference backend.
86*89c4ff92SAndroid Build Coastguard Worker // But in general usage Optimize could fail if the hardware at runtime cannot
87*89c4ff92SAndroid Build Coastguard Worker // support the model that has been provided.
88*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Error: Failed to optimise the input network." << std::endl;
89*89c4ff92SAndroid Build Coastguard Worker return 1;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker // Load graph into runtime.
93*89c4ff92SAndroid Build Coastguard Worker std::string errmsg; // To hold an eventual error message if loading the network fails
94*89c4ff92SAndroid Build Coastguard Worker // Add network properties to enable async execution. The MemorySource::Undefined variables indicate
95*89c4ff92SAndroid Build Coastguard Worker // that neither inputs nor outputs will be imported. Importing will be covered in another example.
96*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(true, MemorySource::Undefined, MemorySource::Undefined);
97*89c4ff92SAndroid Build Coastguard Worker run->LoadNetwork(networkIdentifier,
98*89c4ff92SAndroid Build Coastguard Worker std::move(optNet),
99*89c4ff92SAndroid Build Coastguard Worker errmsg,
100*89c4ff92SAndroid Build Coastguard Worker networkProperties);
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker // Creates structures for inputs and outputs. A vector of float for each execution.
103*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<float>> inputData{{number1}, {number2}};
104*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<float>> outputData;
105*89c4ff92SAndroid Build Coastguard Worker outputData.resize(2, std::vector<float>(1));
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
108*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
109*89c4ff92SAndroid Build Coastguard Worker std::vector<InputTensors> inputTensors
110*89c4ff92SAndroid Build Coastguard Worker {
111*89c4ff92SAndroid Build Coastguard Worker {{0, armnn::ConstTensor(inputTensorInfo, inputData[0].data())}},
112*89c4ff92SAndroid Build Coastguard Worker {{0, armnn::ConstTensor(inputTensorInfo, inputData[1].data())}}
113*89c4ff92SAndroid Build Coastguard Worker };
114*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputTensors> outputTensors
115*89c4ff92SAndroid Build Coastguard Worker {
116*89c4ff92SAndroid Build Coastguard Worker {{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData[0].data())}},
117*89c4ff92SAndroid Build Coastguard Worker {{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0), outputData[1].data())}}
118*89c4ff92SAndroid Build Coastguard Worker };
119*89c4ff92SAndroid Build Coastguard Worker
120*89c4ff92SAndroid Build Coastguard Worker // Lambda function to execute the network. We use it as thread function.
121*89c4ff92SAndroid Build Coastguard Worker auto execute = [&](unsigned int executionIndex)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker auto memHandle = run->CreateWorkingMemHandle(networkIdentifier);
124*89c4ff92SAndroid Build Coastguard Worker run->Execute(*memHandle, inputTensors[executionIndex], outputTensors[executionIndex]);
125*89c4ff92SAndroid Build Coastguard Worker };
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker // Prepare some threads and let each execute the network with a different input
128*89c4ff92SAndroid Build Coastguard Worker std::vector<std::thread> threads;
129*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputTensors.size(); ++i)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker threads.emplace_back(std::thread(execute, i));
132*89c4ff92SAndroid Build Coastguard Worker }
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker // Wait for the threads to finish
135*89c4ff92SAndroid Build Coastguard Worker for (std::thread& t : threads)
136*89c4ff92SAndroid Build Coastguard Worker {
137*89c4ff92SAndroid Build Coastguard Worker if(t.joinable())
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker t.join();
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker
143*89c4ff92SAndroid Build Coastguard Worker std::cout << "Your numbers were " << outputData[0][0] << " and " << outputData[1][0] << std::endl;
144*89c4ff92SAndroid Build Coastguard Worker return 0;
145*89c4ff92SAndroid Build Coastguard Worker
146*89c4ff92SAndroid Build Coastguard Worker }
147