//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <armnn/INetwork.hpp>
#include <armnn/IRuntime.hpp>
#include <armnn/Utils.hpp>
#include <armnn/Descriptors.hpp>

#include <iostream>
#include <string>
#include <utility>
#include <vector>
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker /// A simple example of using the ArmNN SDK API. In this sample, the users single input number is multiplied by 1.0f
13*89c4ff92SAndroid Build Coastguard Worker /// using a fully connected layer with a single neuron to produce an output number that is the same as the input.
main()14*89c4ff92SAndroid Build Coastguard Worker int main()
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker float number;
19*89c4ff92SAndroid Build Coastguard Worker std::cout << "Please enter a number: " << std::endl;
20*89c4ff92SAndroid Build Coastguard Worker std::cin >> number;
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker // Turn on logging to standard output
23*89c4ff92SAndroid Build Coastguard Worker // This is useful in this sample so that users can learn more about what is going on
24*89c4ff92SAndroid Build Coastguard Worker ConfigureLogging(true, false, LogSeverity::Warning);
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker // Construct ArmNN network
27*89c4ff92SAndroid Build Coastguard Worker NetworkId networkIdentifier;
28*89c4ff92SAndroid Build Coastguard Worker INetworkPtr myNetwork = INetwork::Create();
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker float weightsData[] = {1.0f}; // Identity
31*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32, 0.0f, 0, true);
32*89c4ff92SAndroid Build Coastguard Worker weightsInfo.SetConstant();
33*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(weightsInfo, weightsData);
34*89c4ff92SAndroid Build Coastguard Worker
35*89c4ff92SAndroid Build Coastguard Worker // Constant layer that now holds weights data for FullyConnected
36*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const constantWeightsLayer = myNetwork->AddConstantLayer(weights, "const weights");
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker FullyConnectedDescriptor fullyConnectedDesc;
39*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const fullyConnectedLayer = myNetwork->AddFullyConnectedLayer(fullyConnectedDesc,
40*89c4ff92SAndroid Build Coastguard Worker "fully connected");
41*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* InputLayer = myNetwork->AddInputLayer(0);
42*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* OutputLayer = myNetwork->AddOutputLayer(0);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker InputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
45*89c4ff92SAndroid Build Coastguard Worker constantWeightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
46*89c4ff92SAndroid Build Coastguard Worker fullyConnectedLayer->GetOutputSlot(0).Connect(OutputLayer->GetInputSlot(0));
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker // Create ArmNN runtime
49*89c4ff92SAndroid Build Coastguard Worker IRuntime::CreationOptions options; // default options
50*89c4ff92SAndroid Build Coastguard Worker IRuntimePtr run = IRuntime::Create(options);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker //Set the tensors in the network.
53*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
54*89c4ff92SAndroid Build Coastguard Worker InputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
55*89c4ff92SAndroid Build Coastguard Worker
56*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
57*89c4ff92SAndroid Build Coastguard Worker fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
58*89c4ff92SAndroid Build Coastguard Worker constantWeightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker // Optimise ArmNN network
61*89c4ff92SAndroid Build Coastguard Worker IOptimizedNetworkPtr optNet = Optimize(*myNetwork, {Compute::CpuRef}, run->GetDeviceSpec());
62*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker // This shouldn't happen for this simple sample, with reference backend.
65*89c4ff92SAndroid Build Coastguard Worker // But in general usage Optimize could fail if the hardware at runtime cannot
66*89c4ff92SAndroid Build Coastguard Worker // support the model that has been provided.
67*89c4ff92SAndroid Build Coastguard Worker std::cerr << "Error: Failed to optimise the input network." << std::endl;
68*89c4ff92SAndroid Build Coastguard Worker return 1;
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker // Load graph into runtime
72*89c4ff92SAndroid Build Coastguard Worker run->LoadNetwork(networkIdentifier, std::move(optNet));
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker //Creates structures for inputs and outputs.
75*89c4ff92SAndroid Build Coastguard Worker std::vector<float> inputData{number};
76*89c4ff92SAndroid Build Coastguard Worker std::vector<float> outputData(1);
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo = run->GetInputTensorInfo(networkIdentifier, 0);
79*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant(true);
80*89c4ff92SAndroid Build Coastguard Worker InputTensors inputTensors{{0, armnn::ConstTensor(inputTensorInfo,
81*89c4ff92SAndroid Build Coastguard Worker inputData.data())}};
82*89c4ff92SAndroid Build Coastguard Worker OutputTensors outputTensors{{0, armnn::Tensor(run->GetOutputTensorInfo(networkIdentifier, 0),
83*89c4ff92SAndroid Build Coastguard Worker outputData.data())}};
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker // Execute network
86*89c4ff92SAndroid Build Coastguard Worker run->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker std::cout << "Your number was " << outputData[0] << std::endl;
89*89c4ff92SAndroid Build Coastguard Worker return 0;
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker }
92