1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2020 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <assert.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <stdint.h>
9*4bdc9457SAndroid Build Coastguard Worker #include <stdlib.h>
10*4bdc9457SAndroid Build Coastguard Worker
11*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/memory-planner.h>
12*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
13*4bdc9457SAndroid Build Coastguard Worker
14*4bdc9457SAndroid Build Coastguard Worker // Check if two xnn_value's lifecycles overlap.
value_lifecycle_overlap(const struct xnn_value_usage * a,const struct xnn_value_usage * b)15*4bdc9457SAndroid Build Coastguard Worker inline static bool value_lifecycle_overlap(const struct xnn_value_usage* a, const struct xnn_value_usage* b) {
16*4bdc9457SAndroid Build Coastguard Worker assert(a->last_node >= a->first_node);
17*4bdc9457SAndroid Build Coastguard Worker assert(b->last_node >= b->first_node);
18*4bdc9457SAndroid Build Coastguard Worker if (a->first_node < b->first_node) {
19*4bdc9457SAndroid Build Coastguard Worker return a->last_node >= b->first_node;
20*4bdc9457SAndroid Build Coastguard Worker } else {
21*4bdc9457SAndroid Build Coastguard Worker return b->last_node >= a->first_node;
22*4bdc9457SAndroid Build Coastguard Worker }
23*4bdc9457SAndroid Build Coastguard Worker }
24*4bdc9457SAndroid Build Coastguard Worker
25*4bdc9457SAndroid Build Coastguard Worker // Use this comparison function to sort xnn_value_usage according to the
26*4bdc9457SAndroid Build Coastguard Worker // tensor_size in decreasing order.
cmp_value_usage_tensor_size(const void * a,const void * b)27*4bdc9457SAndroid Build Coastguard Worker static inline int cmp_value_usage_tensor_size(const void* a, const void* b) {
28*4bdc9457SAndroid Build Coastguard Worker const size_t tensor_size_a = (*(struct xnn_value_usage *const*)a)->tensor_size;
29*4bdc9457SAndroid Build Coastguard Worker const size_t tensor_size_b = (*(struct xnn_value_usage *const*)b)->tensor_size;
30*4bdc9457SAndroid Build Coastguard Worker return (tensor_size_b > tensor_size_a) - (tensor_size_b < tensor_size_a);
31*4bdc9457SAndroid Build Coastguard Worker }
32*4bdc9457SAndroid Build Coastguard Worker
populate_value_lifecycle(const xnn_subgraph_t subgraph,struct xnn_value_usage * usage)33*4bdc9457SAndroid Build Coastguard Worker static void populate_value_lifecycle(const xnn_subgraph_t subgraph, struct xnn_value_usage* usage) {
34*4bdc9457SAndroid Build Coastguard Worker assert(subgraph != NULL);
35*4bdc9457SAndroid Build Coastguard Worker if (subgraph->num_nodes == 0) {
36*4bdc9457SAndroid Build Coastguard Worker return;
37*4bdc9457SAndroid Build Coastguard Worker }
38*4bdc9457SAndroid Build Coastguard Worker // As we initialized first/last_node in each xnn_value_usage to 0 as in 'xnn_init_value_mem_allocation_tracker',
39*4bdc9457SAndroid Build Coastguard Worker // we start with the second node to tell whether first/last_node have been set or not, and check the first node last.
40*4bdc9457SAndroid Build Coastguard Worker for (uint32_t nid = 1; nid < subgraph->num_nodes; ++nid) {
41*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* node = subgraph->nodes + nid;
42*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_inputs; ++i) {
43*4bdc9457SAndroid Build Coastguard Worker if (usage[node->inputs[i]].first_node == 0) {
44*4bdc9457SAndroid Build Coastguard Worker usage[node->inputs[i]].first_node = nid;
45*4bdc9457SAndroid Build Coastguard Worker }
46*4bdc9457SAndroid Build Coastguard Worker usage[node->inputs[i]].last_node = nid;
47*4bdc9457SAndroid Build Coastguard Worker }
48*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < node->num_outputs; ++i) {
49*4bdc9457SAndroid Build Coastguard Worker if (usage[node->outputs[i]].first_node == 0) {
50*4bdc9457SAndroid Build Coastguard Worker usage[node->outputs[i]].first_node = nid;
51*4bdc9457SAndroid Build Coastguard Worker }
52*4bdc9457SAndroid Build Coastguard Worker usage[node->outputs[i]].last_node = nid;
53*4bdc9457SAndroid Build Coastguard Worker }
54*4bdc9457SAndroid Build Coastguard Worker }
55*4bdc9457SAndroid Build Coastguard Worker const struct xnn_node* first_node = subgraph->nodes;
56*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < first_node->num_inputs; ++i) {
57*4bdc9457SAndroid Build Coastguard Worker usage[first_node->inputs[i]].first_node = 0;
58*4bdc9457SAndroid Build Coastguard Worker }
59*4bdc9457SAndroid Build Coastguard Worker for (uint32_t i = 0; i < first_node->num_outputs; ++i) {
60*4bdc9457SAndroid Build Coastguard Worker usage[first_node->outputs[i]].first_node = 0;
61*4bdc9457SAndroid Build Coastguard Worker }
62*4bdc9457SAndroid Build Coastguard Worker }
63*4bdc9457SAndroid Build Coastguard Worker
// A half-open memory region [start, end) within the arena.
struct memory_block {
  size_t start;
  size_t end;
};

// qsort comparator: orders memory_block entries by 'start' in increasing order.
static inline int cmp_memory_block(const void* a, const void* b) {
  const size_t start_a = ((const struct memory_block*) a)->start;
  const size_t start_b = ((const struct memory_block*) b)->start;
  if (start_a < start_b) {
    return -1;
  }
  if (start_a > start_b) {
    return 1;
  }
  return 0;
}

// Picks an arena offset for a new value of 'to_alloc_size' bytes, given the
// blocks currently occupied by live values. The blocks are sorted and merged
// in place, then the smallest inter-block gap that still fits the request is
// chosen (best-fit); if no gap fits, the value is placed after the last block.
static size_t find_value_alloc_offset(struct memory_block* live_mem_blocks,
                                      size_t num_mem_blocks,
                                      size_t to_alloc_size) {
  // Empty arena: allocate at the beginning.
  if (num_mem_blocks == 0) {
    return 0;
  }

  // A single live block: simply append after it.
  if (num_mem_blocks == 1) {
    return live_mem_blocks[0].end;
  }

  // Sort blocks by their start offset so gaps can be computed between
  // consecutive blocks.
  qsort(live_mem_blocks, num_mem_blocks, sizeof(struct memory_block), cmp_memory_block);

  // Merge overlapping or immediately adjacent blocks in place, producing a
  // strictly disjoint, ordered prefix of length 'num_coalesced'.
  size_t num_coalesced = 1;
  for (size_t i = 1; i < num_mem_blocks; ++i) {
    struct memory_block* merged_tail = &live_mem_blocks[num_coalesced - 1];
    if (live_mem_blocks[i].start > merged_tail->end) {
      // Disjoint from the current tail: starts a new coalesced block.
      assert(num_coalesced <= i);
      live_mem_blocks[num_coalesced++] = live_mem_blocks[i];
    } else if (live_mem_blocks[i].end > merged_tail->end) {
      // Overlaps/touches the tail: extend it.
      merged_tail->end = live_mem_blocks[i].end;
    }
  }

  size_t best_gap_size = SIZE_MAX;
  // Index of the coalesced block whose end becomes the allocation offset;
  // defaults to the last block (i.e. allocate past everything).
  size_t best_gap_index = num_coalesced - 1;
  for (size_t i = 0; i + 1 < num_coalesced; ++i) {
    assert(live_mem_blocks[i + 1].start > live_mem_blocks[i].end);
    const size_t gap = live_mem_blocks[i + 1].start - live_mem_blocks[i].end;
    // Best-fit: prefer the tightest gap that still fits the request.
    if (gap >= to_alloc_size && gap < best_gap_size) {
      best_gap_size = gap;
      best_gap_index = i;
    }
  }
  return live_mem_blocks[best_gap_index].end;
}
124*4bdc9457SAndroid Build Coastguard Worker
xnn_init_value_allocation_tracker(struct xnn_value_allocation_tracker * tracker,const xnn_subgraph_t subgraph)125*4bdc9457SAndroid Build Coastguard Worker void xnn_init_value_allocation_tracker(struct xnn_value_allocation_tracker* tracker, const xnn_subgraph_t subgraph) {
126*4bdc9457SAndroid Build Coastguard Worker tracker->subgraph = subgraph;
127*4bdc9457SAndroid Build Coastguard Worker tracker->mem_arena_size = 0;
128*4bdc9457SAndroid Build Coastguard Worker tracker->usage = xnn_allocate_zero_memory(sizeof(struct xnn_value_usage) * subgraph->num_values);
129*4bdc9457SAndroid Build Coastguard Worker #if XNN_ENABLE_MEMOPT
130*4bdc9457SAndroid Build Coastguard Worker populate_value_lifecycle(tracker->subgraph, tracker->usage);
131*4bdc9457SAndroid Build Coastguard Worker #endif
132*4bdc9457SAndroid Build Coastguard Worker tracker->min_value_id = XNN_INVALID_VALUE_ID;
133*4bdc9457SAndroid Build Coastguard Worker tracker->max_value_id = XNN_INVALID_VALUE_ID;
134*4bdc9457SAndroid Build Coastguard Worker }
135*4bdc9457SAndroid Build Coastguard Worker
xnn_add_value_allocation_tracker(struct xnn_value_allocation_tracker * tracker,uint32_t value_id,size_t tensor_size)136*4bdc9457SAndroid Build Coastguard Worker void xnn_add_value_allocation_tracker(struct xnn_value_allocation_tracker* tracker,
137*4bdc9457SAndroid Build Coastguard Worker uint32_t value_id,
138*4bdc9457SAndroid Build Coastguard Worker size_t tensor_size) {
139*4bdc9457SAndroid Build Coastguard Worker tracker->usage[value_id].tensor_size = tensor_size;
140*4bdc9457SAndroid Build Coastguard Worker if (tracker->min_value_id == XNN_INVALID_VALUE_ID) {
141*4bdc9457SAndroid Build Coastguard Worker tracker->min_value_id = value_id;
142*4bdc9457SAndroid Build Coastguard Worker } else {
143*4bdc9457SAndroid Build Coastguard Worker // Note that values are expected to be added in increasing order.
144*4bdc9457SAndroid Build Coastguard Worker assert(value_id > tracker->min_value_id);
145*4bdc9457SAndroid Build Coastguard Worker assert(value_id > tracker->max_value_id);
146*4bdc9457SAndroid Build Coastguard Worker }
147*4bdc9457SAndroid Build Coastguard Worker
148*4bdc9457SAndroid Build Coastguard Worker tracker->max_value_id = value_id;
149*4bdc9457SAndroid Build Coastguard Worker }
150*4bdc9457SAndroid Build Coastguard Worker
// Computes an arena offset for every tracked value and the total arena size.
//
// With XNN_ENABLE_MEMOPT: a greedy planner that sorts values by tensor size
// (largest first) and, for each value, collects the blocks of already-placed
// values whose lifecycles overlap it, then asks find_value_alloc_offset for a
// best-fit offset among those live blocks. Values with disjoint lifecycles may
// thus reuse the same memory. Without XNN_ENABLE_MEMOPT: every non-empty value
// simply gets its own non-overlapping slice of the arena.
void xnn_plan_value_allocation_tracker(struct xnn_value_allocation_tracker* tracker) {
#if XNN_ENABLE_MEMOPT
  // No values were registered: nothing to plan.
  if (tracker->min_value_id == XNN_INVALID_VALUE_ID) {
    assert(tracker->max_value_id == XNN_INVALID_VALUE_ID);
    return;
  }

  const uint32_t num_values = tracker->max_value_id - tracker->min_value_id + 1;
  struct xnn_value_usage** sorted_usage = xnn_allocate_zero_memory(sizeof(struct xnn_value_usage*) * num_values);
  size_t num_values_to_alloc = 0;
  // Gather pointers to the values that actually need memory (tensor_size > 0).
  for (size_t i = tracker->min_value_id; i <= tracker->max_value_id; ++i) {
    struct xnn_value_usage* info = tracker->usage + i;
    if (info->tensor_size != 0) {
      sorted_usage[num_values_to_alloc++] = info;
    }
  }
  // Place larger tensors first: a common heuristic that reduces fragmentation.
  qsort(sorted_usage, num_values_to_alloc, sizeof(struct xnn_value_usage*), cmp_value_usage_tensor_size);

  // Start the allocation planning process.
  struct memory_block* current_live_mem_blocks = xnn_allocate_zero_memory(
      sizeof(struct memory_block) * num_values_to_alloc);
  size_t mem_arena_size = 0;
  for (size_t i = 0; i < num_values_to_alloc; ++i) {
    size_t num_live_mem_blocks = 0;
    struct xnn_value_usage* current = sorted_usage[i];
    // Only values already placed (j < i) whose lifecycles overlap 'current'
    // constrain its placement; values with disjoint lifetimes can share memory.
    for (size_t j = 0; j < i; ++j) {
      const struct xnn_value_usage* allocated = sorted_usage[j];
      if (value_lifecycle_overlap(current, allocated)) {
        current_live_mem_blocks[num_live_mem_blocks++] = (struct memory_block){
            .start = allocated->alloc_offset,
            .end = allocated->alloc_offset + allocated->tensor_size,
        };
      }
    }
    current->alloc_offset = find_value_alloc_offset(current_live_mem_blocks, num_live_mem_blocks, current->tensor_size);
    // Grow the arena if this placement extends past the current high-water mark.
    if (mem_arena_size < current->alloc_offset + current->tensor_size) {
      mem_arena_size = current->alloc_offset + current->tensor_size;
    }
  }

  tracker->mem_arena_size = mem_arena_size;
  xnn_release_memory(sorted_usage);
  xnn_release_memory(current_live_mem_blocks);
#else
  // Fallback: no memory reuse — lay every non-empty value out back-to-back.
  tracker->mem_arena_size = 0;
  for (uint32_t i = tracker->min_value_id; i <= tracker->max_value_id; ++i) {
    if (tracker->usage[i].tensor_size > 0) {
      tracker->usage[i].alloc_offset = tracker->mem_arena_size;
      tracker->mem_arena_size += tracker->usage[i].tensor_size;
    }
  }
#endif
}
204