xref: /aosp_15_r20/external/executorch/runtime/executor/platform_memory_allocator.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <stdio.h>
12 #include <cinttypes>
13 #include <cstdint>
14 
15 #include <executorch/runtime/core/memory_allocator.h>
16 #include <executorch/runtime/platform/log.h>
17 #include <executorch/runtime/platform/platform.h>
18 
19 namespace executorch {
20 namespace runtime {
21 namespace internal {
22 
23 /**
24  * PlatformMemoryAllocator is a memory allocator that uses a linked list to
25  * manage allocated nodes. It overrides the allocate method of MemoryAllocator
26  * using the PAL fallback allocator method `et_pal_allocate`.
27  */
28 class PlatformMemoryAllocator final : public MemoryAllocator {
29  private:
30   // We allocate a little more than requested and use that memory as a node in
31   // a linked list, pushing the allocated buffers onto a list that's iterated
32   // and freed when the KernelRuntimeContext is destroyed.
33   struct AllocationNode {
34     void* data;
35     AllocationNode* next;
36   };
37 
38   AllocationNode* head_ = nullptr;
39 
40  public:
PlatformMemoryAllocator()41   PlatformMemoryAllocator() : MemoryAllocator(0, nullptr) {}
42 
43   void* allocate(size_t size, size_t alignment = kDefaultAlignment) override {
44     if (!isPowerOf2(alignment)) {
45       ET_LOG(Error, "Alignment %zu is not a power of 2", alignment);
46       return nullptr;
47     }
48 
49     // Allocate enough memory for the node, the data and the alignment bump.
50     size_t alloc_size = sizeof(AllocationNode) + size + alignment;
51     void* node_memory = et_pal_allocate(alloc_size);
52 
53     // If allocation failed, log message and return nullptr.
54     if (node_memory == nullptr) {
55       ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size);
56       return nullptr;
57     }
58 
59     // Compute data pointer.
60     uint8_t* data_ptr =
61         reinterpret_cast<uint8_t*>(node_memory) + sizeof(AllocationNode);
62 
63     // Align the data pointer.
64     void* aligned_data_ptr = alignPointer(data_ptr, alignment);
65 
66     // Assert that the alignment didn't overflow the allocated memory.
67     ET_DCHECK_MSG(
68         reinterpret_cast<uintptr_t>(aligned_data_ptr) + size <=
69             reinterpret_cast<uintptr_t>(node_memory) + alloc_size,
70         "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu",
71         aligned_data_ptr,
72         size,
73         node_memory,
74         alloc_size);
75 
76     // Construct the node.
77     AllocationNode* new_node = reinterpret_cast<AllocationNode*>(node_memory);
78     new_node->data = aligned_data_ptr;
79     new_node->next = head_;
80     head_ = new_node;
81 
82     // Return the aligned data pointer.
83     return head_->data;
84   }
85 
reset()86   void reset() override {
87     AllocationNode* current = head_;
88     while (current != nullptr) {
89       AllocationNode* next = current->next;
90       et_pal_free(current);
91       current = next;
92     }
93     head_ = nullptr;
94   }
95 
~PlatformMemoryAllocator()96   ~PlatformMemoryAllocator() override {
97     reset();
98   }
99 
100  private:
101   // Disable copy and move.
102   PlatformMemoryAllocator(const PlatformMemoryAllocator&) = delete;
103   PlatformMemoryAllocator& operator=(const PlatformMemoryAllocator&) = delete;
104   PlatformMemoryAllocator(PlatformMemoryAllocator&&) noexcept = delete;
105   PlatformMemoryAllocator& operator=(PlatformMemoryAllocator&&) noexcept =
106       delete;
107 };
108 
109 } // namespace internal
110 } // namespace runtime
111 } // namespace executorch
112