xref: /aosp_15_r20/external/executorch/backends/mediatek/runtime/include/NeuronBufferAllocator.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include "NeuronExecutor.h"
#include "NeuronLog.h"
#include "api/NeuronAdapter.h"

#include <android/hardware_buffer.h>

#include <executorch/runtime/core/memory_allocator.h>

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <new>
// Shorthand for the BufferAllocator singleton accessor.
#define GET_NEURON_ALLOCATOR ::torch::executor::neuron::BufferAllocator::GetInstance()
26 
27 // TODO: Move this code to the executorch::backends::neuron namespace.
28 // The torch:: namespace is deprecated for ExecuTorch code.
29 namespace torch {
30 namespace executor {
31 namespace neuron {
32 
33 struct BufferDeleter {
operatorBufferDeleter34   void operator()(AHardwareBuffer* buffer) {
35     if (buffer != nullptr) {
36       AHardwareBuffer_unlock(buffer, nullptr);
37       AHardwareBuffer_release(buffer);
38     }
39   }
40 };
41 
42 class MemoryUnit {
43  public:
Create(size_t size)44   static std::unique_ptr<MemoryUnit> Create(size_t size) {
45     auto obj = std::unique_ptr<MemoryUnit>(new (std::nothrow) MemoryUnit(size));
46     return (obj && (obj->Allocate() == NEURON_NO_ERROR)) ? std::move(obj)
47                                                          : nullptr;
48   }
49 
~MemoryUnit()50   ~MemoryUnit() {
51     mNeuronMemory.reset();
52     mAhwb.reset();
53   }
54 
GetSize()55   size_t GetSize() const {
56     return mSize;
57   }
58 
GetAddress()59   void* GetAddress() const {
60     return mAddress;
61   }
62 
GetNeuronMemory()63   NeuronMemory* GetNeuronMemory() const {
64     return mNeuronMemory.get();
65   }
66 
67  private:
MemoryUnit(size_t size)68   explicit MemoryUnit(size_t size) : mSize(size) {}
69 
Allocate()70   int Allocate() {
71     AHardwareBuffer_Desc iDesc{
72         .width = static_cast<uint32_t>(mSize),
73         .height = 1,
74         .layers = 1,
75         .format = AHARDWAREBUFFER_FORMAT_BLOB,
76         .usage = mAhwbType,
77         .stride = static_cast<uint32_t>(mSize),
78     };
79     AHardwareBuffer* Abuffer = nullptr;
80     AHardwareBuffer_allocate(&iDesc, &Abuffer);
81     CHECK_VALID_PTR(Abuffer);
82     mAhwb = std::unique_ptr<AHardwareBuffer, BufferDeleter>(Abuffer);
83 
84     NeuronMemory* memory = nullptr;
85     NeuronMemory_createFromAHardwareBuffer(Abuffer, &memory);
86     CHECK_VALID_PTR(memory);
87     mNeuronMemory = std::
88         unique_ptr<NeuronMemory, executorch::backends::neuron::NeuronDeleter>(
89             memory);
90 
91     AHardwareBuffer_lock(Abuffer, mAhwbType, -1, nullptr, &mAddress);
92     CHECK_VALID_PTR(mAddress);
93     return NEURON_NO_ERROR;
94   }
95 
96  private:
97   std::unique_ptr<NeuronMemory, executorch::backends::neuron::NeuronDeleter>
98       mNeuronMemory;
99 
100   std::unique_ptr<AHardwareBuffer, BufferDeleter> mAhwb;
101 
102   uint64_t mAhwbType = AHARDWAREBUFFER_USAGE_CPU_READ_OFTEN |
103       AHARDWAREBUFFER_USAGE_CPU_WRITE_OFTEN;
104 
105   void* mAddress = nullptr;
106 
107   size_t mSize = 0;
108 };
109 
110 class BufferAllocator : public executorch::runtime::MemoryAllocator {
111  public:
112   static BufferAllocator& GetInstance();
113 
114   void* Allocate(size_t size);
115 
116   void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
117     return Allocate(size);
118   }
119 
120   bool RemoveBuffer(void* address);
121 
122   const MemoryUnit* Find(void* address);
123 
124   void Clear();
125 
126  private:
BufferAllocator()127   BufferAllocator() : executorch::runtime::MemoryAllocator(0, nullptr) {}
128 
129   BufferAllocator(const BufferAllocator&) = delete;
130 
131   BufferAllocator& operator=(const BufferAllocator&) = delete;
132 
~BufferAllocator()133   ~BufferAllocator() override {
134     Clear();
135   }
136 
137  private:
138   std::map<void*, std::unique_ptr<MemoryUnit>> mPool;
139 
140   std::mutex mMutex;
141 };
142 
143 } // namespace neuron
144 } // namespace executor
145 } // namespace torch
146