xref: /aosp_15_r20/external/armnn/samples/CustomMemoryAllocatorSample.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include <armnn/ArmNN.hpp>
#include <armnn/backends/ICustomAllocator.hpp>

#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/runtime/CL/CLScheduler.h>

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker /** Sample implementation of ICustomAllocator for use with the ClBackend.
15*89c4ff92SAndroid Build Coastguard Worker  *  Note: any memory allocated must be host addressable with write access
16*89c4ff92SAndroid Build Coastguard Worker  *  in order for ArmNN to be able to properly use it. */
17*89c4ff92SAndroid Build Coastguard Worker class SampleClBackendCustomAllocator : public armnn::ICustomAllocator
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker public:
20*89c4ff92SAndroid Build Coastguard Worker     SampleClBackendCustomAllocator() = default;
21*89c4ff92SAndroid Build Coastguard Worker 
allocate(size_t size,size_t alignment)22*89c4ff92SAndroid Build Coastguard Worker     void* allocate(size_t size, size_t alignment) override
23*89c4ff92SAndroid Build Coastguard Worker     {
24*89c4ff92SAndroid Build Coastguard Worker         // If alignment is 0 just use the CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE for alignment
25*89c4ff92SAndroid Build Coastguard Worker         if (alignment == 0)
26*89c4ff92SAndroid Build Coastguard Worker         {
27*89c4ff92SAndroid Build Coastguard Worker             alignment = arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
28*89c4ff92SAndroid Build Coastguard Worker         }
29*89c4ff92SAndroid Build Coastguard Worker         size_t space = size + alignment + alignment;
30*89c4ff92SAndroid Build Coastguard Worker         auto allocatedMemPtr = std::malloc(space * sizeof(size_t));
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker         if (std::align(alignment, size, allocatedMemPtr, space) == nullptr)
33*89c4ff92SAndroid Build Coastguard Worker         {
34*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("SampleClBackendCustomAllocator::Alignment failed");
35*89c4ff92SAndroid Build Coastguard Worker         }
36*89c4ff92SAndroid Build Coastguard Worker         return allocatedMemPtr;
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker 
free(void * ptr)39*89c4ff92SAndroid Build Coastguard Worker     void free(void* ptr) override
40*89c4ff92SAndroid Build Coastguard Worker     {
41*89c4ff92SAndroid Build Coastguard Worker         std::free(ptr);
42*89c4ff92SAndroid Build Coastguard Worker     }
43*89c4ff92SAndroid Build Coastguard Worker 
GetMemorySourceType()44*89c4ff92SAndroid Build Coastguard Worker     armnn::MemorySource GetMemorySourceType() override
45*89c4ff92SAndroid Build Coastguard Worker     {
46*89c4ff92SAndroid Build Coastguard Worker         return armnn::MemorySource::Malloc;
47*89c4ff92SAndroid Build Coastguard Worker     }
48*89c4ff92SAndroid Build Coastguard Worker };
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker // A simple example application to show the usage of a custom memory allocator. In this sample, the users single
52*89c4ff92SAndroid Build Coastguard Worker // input number is multiplied by 1.0f using a fully connected layer with a single neuron to produce an output
53*89c4ff92SAndroid Build Coastguard Worker // number that is the same as the input. All memory required to execute this mini network is allocated with
54*89c4ff92SAndroid Build Coastguard Worker // the provided custom allocator.
55*89c4ff92SAndroid Build Coastguard Worker //
56*89c4ff92SAndroid Build Coastguard Worker // Using a Custom Allocator is required for use with Protected Mode and Protected Memory.
57*89c4ff92SAndroid Build Coastguard Worker // This example is provided using only unprotected malloc as Protected Memory is platform
58*89c4ff92SAndroid Build Coastguard Worker // and implementation specific.
59*89c4ff92SAndroid Build Coastguard Worker //
60*89c4ff92SAndroid Build Coastguard Worker // Note: This example is similar to the SimpleSample application that can also be found in armnn/samples.
61*89c4ff92SAndroid Build Coastguard Worker //       The differences are in the use of a custom allocator, the backend is GpuAcc, and the inputs/outputs
62*89c4ff92SAndroid Build Coastguard Worker //       are being imported instead of copied. (Import must be enabled when using a Custom Allocator)
63*89c4ff92SAndroid Build Coastguard Worker //       You might find this useful for comparison.
main()64*89c4ff92SAndroid Build Coastguard Worker int main()
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     float number;
69*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Please enter a number: " << std::endl;
70*89c4ff92SAndroid Build Coastguard Worker     std::cin >> number;
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     // Turn on logging to standard output
73*89c4ff92SAndroid Build Coastguard Worker     // This is useful in this sample so that users can learn more about what is going on
74*89c4ff92SAndroid Build Coastguard Worker     ConfigureLogging(true, false, LogSeverity::Info);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     // Construct ArmNN network
77*89c4ff92SAndroid Build Coastguard Worker     NetworkId networkIdentifier;
78*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr network = INetwork::Create();
79*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor fullyConnectedDesc;
80*89c4ff92SAndroid Build Coastguard Worker     float weightsData[] = {1.0f}; // Identity
81*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightsInfo(TensorShape({1, 1}), DataType::Float32, 0.0f, 0, true);
82*89c4ff92SAndroid Build Coastguard Worker     weightsInfo.SetConstant(true);
83*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(weightsInfo, weightsData);
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* inputLayer   = network->AddInputLayer(0);
86*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* weightsLayer = network->AddConstantLayer(weights, "Weights");
87*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* fullyConnectedLayer =
88*89c4ff92SAndroid Build Coastguard Worker             network->AddFullyConnectedLayer(fullyConnectedDesc, "fully connected");
89*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* outputLayer  = network->AddOutputLayer(0);
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(0));
92*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(fullyConnectedLayer->GetInputSlot(1));
93*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
94*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsInfo);
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     // Create ArmNN runtime:
97*89c4ff92SAndroid Build Coastguard Worker     //
98*89c4ff92SAndroid Build Coastguard Worker     // This is the interesting bit when executing a model with a custom allocator.
99*89c4ff92SAndroid Build Coastguard Worker     // You can have different allocators for different backends. To support this
100*89c4ff92SAndroid Build Coastguard Worker     // the runtime creation option has a map that takes a BackendId and the corresponding
101*89c4ff92SAndroid Build Coastguard Worker     // allocator that should be used for that backend.
102*89c4ff92SAndroid Build Coastguard Worker     // Only GpuAcc supports a Custom Allocator for now
103*89c4ff92SAndroid Build Coastguard Worker     //
104*89c4ff92SAndroid Build Coastguard Worker     // Note: This is not covered in this example but if you want to run a model on
105*89c4ff92SAndroid Build Coastguard Worker     //       protected memory a custom allocator needs to be provided that supports
106*89c4ff92SAndroid Build Coastguard Worker     //       protected memory allocations and the MemorySource of that allocator is
107*89c4ff92SAndroid Build Coastguard Worker     //       set to MemorySource::DmaBufProtected
108*89c4ff92SAndroid Build Coastguard Worker     IRuntime::CreationOptions options;
109*89c4ff92SAndroid Build Coastguard Worker     auto customAllocator = std::make_shared<SampleClBackendCustomAllocator>();
110*89c4ff92SAndroid Build Coastguard Worker     options.m_CustomAllocatorMap = {{"GpuAcc", std::move(customAllocator)}};
111*89c4ff92SAndroid Build Coastguard Worker     IRuntimePtr runtime = IRuntime::Create(options);
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker     //Set the tensors in the network.
114*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(TensorShape({1, 1}), DataType::Float32);
115*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker     unsigned int numElements = inputTensorInfo.GetNumElements();
118*89c4ff92SAndroid Build Coastguard Worker     size_t totalBytes = numElements * sizeof(float);
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(TensorShape({1, 1}), DataType::Float32);
121*89c4ff92SAndroid Build Coastguard Worker     fullyConnectedLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     // Optimise ArmNN network
124*89c4ff92SAndroid Build Coastguard Worker     OptimizerOptionsOpaque optOptions;
125*89c4ff92SAndroid Build Coastguard Worker     optOptions.SetImportEnabled(true);
126*89c4ff92SAndroid Build Coastguard Worker     IOptimizedNetworkPtr optNet =
127*89c4ff92SAndroid Build Coastguard Worker                 Optimize(*network, {"GpuAcc"}, runtime->GetDeviceSpec(), optOptions);
128*89c4ff92SAndroid Build Coastguard Worker     if (!optNet)
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         // This shouldn't happen for this simple sample, with GpuAcc backend.
131*89c4ff92SAndroid Build Coastguard Worker         // But in general usage Optimize could fail if the backend at runtime cannot
132*89c4ff92SAndroid Build Coastguard Worker         // support the model that has been provided.
133*89c4ff92SAndroid Build Coastguard Worker         std::cerr << "Error: Failed to optimise the input network." << std::endl;
134*89c4ff92SAndroid Build Coastguard Worker         return 1;
135*89c4ff92SAndroid Build Coastguard Worker     }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker     // Load graph into runtime
138*89c4ff92SAndroid Build Coastguard Worker     std::string ignoredErrorMessage;
139*89c4ff92SAndroid Build Coastguard Worker     INetworkProperties networkProperties(false, MemorySource::Malloc, MemorySource::Malloc);
140*89c4ff92SAndroid Build Coastguard Worker     runtime->LoadNetwork(networkIdentifier, std::move(optNet), ignoredErrorMessage, networkProperties);
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker     // Creates structures for input & output
143*89c4ff92SAndroid Build Coastguard Worker     const size_t alignment =
144*89c4ff92SAndroid Build Coastguard Worker             arm_compute::CLKernelLibrary::get().get_device().getInfo<CL_DEVICE_GLOBAL_MEM_CACHELINE_SIZE>();
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     void* alignedInputPtr = options.m_CustomAllocatorMap["GpuAcc"]->allocate(totalBytes, alignment);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     // Input with negative values
149*89c4ff92SAndroid Build Coastguard Worker     auto* inputPtr = reinterpret_cast<float*>(alignedInputPtr);
150*89c4ff92SAndroid Build Coastguard Worker     std::fill_n(inputPtr, numElements, number);
151*89c4ff92SAndroid Build Coastguard Worker 
152*89c4ff92SAndroid Build Coastguard Worker     void* alignedOutputPtr = options.m_CustomAllocatorMap["GpuAcc"]->allocate(totalBytes, alignment);
153*89c4ff92SAndroid Build Coastguard Worker     auto* outputPtr = reinterpret_cast<float*>(alignedOutputPtr);
154*89c4ff92SAndroid Build Coastguard Worker     std::fill_n(outputPtr, numElements, -10.0f);
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo = runtime->GetInputTensorInfo(networkIdentifier, 0);
157*89c4ff92SAndroid Build Coastguard Worker     inputTensorInfo.SetConstant(true);
158*89c4ff92SAndroid Build Coastguard Worker     InputTensors inputTensors
159*89c4ff92SAndroid Build Coastguard Worker     {
160*89c4ff92SAndroid Build Coastguard Worker         {0, ConstTensor(inputTensorInfo, alignedInputPtr)},
161*89c4ff92SAndroid Build Coastguard Worker     };
162*89c4ff92SAndroid Build Coastguard Worker     OutputTensors outputTensors
163*89c4ff92SAndroid Build Coastguard Worker     {
164*89c4ff92SAndroid Build Coastguard Worker         {0, Tensor(runtime->GetOutputTensorInfo(networkIdentifier, 0), alignedOutputPtr)}
165*89c4ff92SAndroid Build Coastguard Worker     };
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     // Execute network
168*89c4ff92SAndroid Build Coastguard Worker     runtime->EnqueueWorkload(networkIdentifier, inputTensors, outputTensors);
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker     // Tell the CLBackend to sync memory so we can read the output.
171*89c4ff92SAndroid Build Coastguard Worker     arm_compute::CLScheduler::get().sync();
172*89c4ff92SAndroid Build Coastguard Worker     auto* outputResult = reinterpret_cast<float*>(alignedOutputPtr);
173*89c4ff92SAndroid Build Coastguard Worker     std::cout << "Your number was " << outputResult[0] << std::endl;
174*89c4ff92SAndroid Build Coastguard Worker     runtime->UnloadNetwork(networkIdentifier);
175*89c4ff92SAndroid Build Coastguard Worker     return 0;
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker }
178