xref: /aosp_15_r20/external/armnn/samples/SimpleSample.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Utils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker /// A simple example of using the ArmNN SDK API. In this sample, the users single input number is multiplied by 1.0f
13*89c4ff92SAndroid Build Coastguard Worker /// using a fully connected layer with a single neuron to produce an output number that is the same as the input.
main()14*89c4ff92SAndroid Build Coastguard Worker int main()
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker     float number;
19*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Please enter a number: " << std::endl;
20*89c4ff92SAndroid Build Coastguard Worker     std::cin >> number;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     // Turn on logging to standard output
23*89c4ff92SAndroid Build Coastguard Worker     // This is useful in this sample so that users can learn more about what is going on
24*89c4ff92SAndroid Build Coastguard Worker     ConfigureLogging(true, false, LogSeverity::Warning);
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
27*89c4ff92SAndroid Build Coastguard Worker     NetworkId networkIdentifier;
28*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr myNetwork = INetwork::Create();
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     float weightsData[] = {1.0f}; // Identity
31*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32, 0.0f, 0, true);
32*89c4ff92SAndroid Build Coastguard Worker     weightsInfo.SetConstant();
33*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, weightsData);
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     // Constant layer that now holds weights data for FullyConnected
36*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "const weights");
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor fullyConnectedDesc;
39*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
40*89c4ff92SAndroid Build Coastguard Worker                                                                                      "fully connected");
41*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* InputLayer  = myNetwork->AddInputLayer(0);
42*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* OutputLayer = myNetwork->AddOutputLayer(0);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
45*89c4ff92SAndroid Build Coastguard Worker     constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
46*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).Connect(OutputLayer->GetInputSlot(0));
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     // Create ArmNN runtime
49*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options; // default options
50*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr run = IRuntime::Create(options);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     //Set the tensors in the network.
53*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
54*89c4ff92SAndroid Build Coastguard Worker     InputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
57*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
58*89c4ff92SAndroid Build Coastguard Worker     constantWeightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     // Optimise ArmNN network
61*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet = Optimize(*myNetwork, {Compute::CpuRef}, run->GetDeviceSpec());
62*89c4ff92SAndroid Build Coastguard Worker     if (!optNet)
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         // This shouldn't happen for this simple sample, with reference backend.
65*89c4ff92SAndroid Build Coastguard Worker         // But in general usage Optimize could fail if the hardware at runtime cannot
66*89c4ff92SAndroid Build Coastguard Worker         // support the model that has been provided.
67*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Error: Failed to optimise the input network." << std::endl;
68*89c4ff92SAndroid Build Coastguard Worker         return 1;
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     // Load graph into runtime
72*89c4ff92SAndroid Build Coastguard Worker     run->LoadNetwork(networkIdentifier, std::move(optNet));
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     //Creates structures for inputs and outputs.
75*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData{number};
76*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> outputData(1);
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
79*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
80*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors{{0, armnn::ConstTensor(inputTensorInfo,
81*89c4ff92SAndroid Build Coastguard Worker                                                      inputData.data())}};
82*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors{{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0),
83*89c4ff92SAndroid Build Coastguard Worker                                                   outputData.data())}};
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     // Execute network
86*89c4ff92SAndroid Build Coastguard Worker     run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Your number was " << outputData[0] << std::endl;
89*89c4ff92SAndroid Build Coastguard Worker     return 0;
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker }
92