/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/graph/detail/ExecutionHelpers.h"

#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/graph/GraphManager.h"
#include "arm_compute/graph/Tensor.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/backends/BackendRegistry.h"

#include <algorithm> // for std::for_each
#include <memory>    // for std::unique_ptr
#include <utility>   // for std::move

namespace arm_compute
{
namespace graph
{
namespace detail
{
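// Execution helpers used by GraphManager while finalising and executing a graph.

// Validates every node with the backend it has been assigned to, aborting with the
// backend's error description if any node is rejected.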
void validate_all_nodes(Graph &g)
{
    auto &nodes = g.nodes();

    // Validate nodes
    for(auto &node : nodes)
    {
        if(node != nullptr)
        {
            Target                    assigned_target = node->assigned_target();
            backends::IDeviceBackend &backend         = backends::BackendRegistry::get().get_backend(assigned_target);
            Status                    status          = backend.validate_node(*node);
            ARM_COMPUTE_ERROR_ON_MSG(!bool(status), status.error_description().c_str());
        }
    }
}

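// Creates and attaches a backend tensor handle for every graph tensor that does not have one yet.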
void configure_all_tensors(Graph &g)
{
    auto &tensors = g.tensors();

    for(auto &tensor : tensors)
    {
        if(tensor && tensor->handle() == nullptr)
        {
            Target                         target  = tensor->desc().target;
            backends::IDeviceBackend      &backend = backends::BackendRegistry::get().get_backend(target);
            std::unique_ptr<ITensorHandle> handle  = backend.create_tensor(*tensor);
            ARM_COMPUTE_ERROR_ON_MSG(!handle, "Couldn't create backend handle!");
            tensor->set_handle(std::move(handle));
        }
    }
}

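// Allocates the backing memory of every input tensor of the node that is connected to at least one edge.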
void allocate_all_input_tensors(INode &node)
{
    for(unsigned int i = 0; i < node.num_inputs(); ++i)
    {
        Tensor *tensor = node.input(i);
        if(tensor != nullptr && !tensor->bound_edges().empty())
        {
            ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
            tensor->handle()->allocate();
        }
    }
}

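// Allocates the backing memory of every output tensor of the node that is connected to at least one edge.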
void allocate_all_output_tensors(INode &node)
{
    for(unsigned int i = 0; i < node.num_outputs(); ++i)
    {
        Tensor *tensor = node.output(i);
        if(tensor != nullptr && !tensor->bound_edges().empty())
        {
            ARM_COMPUTE_ERROR_ON_MSG(!tensor->handle(), "Tensor handle is not configured!");
            tensor->handle()->allocate();
        }
    }
}

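// Allocates the tensors owned by graph boundary nodes: the outputs of Const and Input
// nodes and the inputs of Output nodes.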
void allocate_const_tensors(Graph &g)
{
    for(auto &node : g.nodes())
    {
        if(node != nullptr)
        {
            switch(node->type())
            {
                case NodeType::Const:
                case NodeType::Input:
                    allocate_all_output_tensors(*node);
                    break;
                case NodeType::Output:
                    allocate_all_input_tensors(*node);
                    break;
                default:
                    break;
            }
        }
    }
}

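// Allocates every remaining tensor that is connected, has a handle, is still resizable and is in use.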
void allocate_all_tensors(Graph &g)
{
    auto &tensors = g.tensors();

    for(auto &tensor : tensors)
    {
        if(tensor && !tensor->bound_edges().empty() && tensor->handle() != nullptr &&
           tensor->handle()->tensor().info()->is_resizable() && tensor->handle()->tensor().is_used())
        {
            tensor->handle()->allocate();
        }
    }
}

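// Configures every node in the given execution order and collects the resulting tasks,
// together with the graph's input and output tensors, into an ExecutionWorkload.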
ExecutionWorkload configure_all_nodes(Graph &g, GraphContext &ctx, const std::vector<NodeID> &node_order)
{
    ExecutionWorkload workload;
    workload.graph = &g;
    workload.ctx   = &ctx;

    // Reserve memory for tasks
    workload.tasks.reserve(node_order.size());

    // Create tasks
    for(auto &node_id : node_order)
    {
        auto node = g.node(node_id);
        if(node != nullptr)
        {
            Target                     assigned_target = node->assigned_target();
            backends::IDeviceBackend  &backend         = backends::BackendRegistry::get().get_backend(assigned_target);
            std::unique_ptr<IFunction> func            = backend.configure_node(*node, ctx);
            if(func != nullptr || is_utility_node(node))
            {
                workload.tasks.emplace_back(ExecutionTask(std::move(func), node));
            }
        }
    }

    // Add inputs and outputs
    for(auto &node : g.nodes())
    {
        if(node != nullptr && node->type() == NodeType::Input)
        {
            workload.inputs.push_back(node->output(0));
        }

        if(node != nullptr && node->type() == NodeType::Output)
        {
            workload.outputs.push_back(node->input(0));
        }
    }

    return workload;
}

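// Asks every tensor handle in the graph to release its memory if it is no longer used.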
void release_unused_tensors(Graph &g)
{
    for(auto &tensor : g.tensors())
    {
        if(tensor != nullptr && tensor->handle() != nullptr)
        {
            tensor->handle()->release_if_unused();
        }
    }
}

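// Calls the accessor attached to a tensor; the tensor must not be null.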
void call_tensor_accessor(Tensor *tensor)
{
    ARM_COMPUTE_ERROR_ON(!tensor);
    tensor->call_accessor();
}

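// Calls the accessor of every connected Const node's output tensor, typically to fill in the constant data.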
void call_all_const_node_accessors(Graph &g)
{
    auto &nodes = g.nodes();

    for(auto &node : nodes)
    {
        if(node != nullptr && node->type() == NodeType::Const && node->num_outputs())
        {
            if(!node->output(0)->bound_edges().empty())
            {
                call_tensor_accessor(node->output(0));
            }
        }
    }
}

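// Calls the accessor of every workload input tensor; returns false if any accessor reports failure.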
bool call_all_input_node_accessors(ExecutionWorkload &workload)
{
    bool is_valid = true;
    std::for_each(std::begin(workload.inputs), std::end(workload.inputs), [&](Tensor * input_tensor)
    {
        bool valid_input = (input_tensor != nullptr) && input_tensor->call_accessor();
        is_valid         = is_valid && valid_input;
    });
    return is_valid;
}

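// Prepares every task in the workload and, after each one, releases any tensors that have become unused.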
void prepare_all_tasks(ExecutionWorkload &workload)
{
    ARM_COMPUTE_ERROR_ON(workload.graph == nullptr);
    for(auto &task : workload.tasks)
    {
        task.prepare();
        release_unused_tensors(*workload.graph);
    }
}

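// Runs every task in the workload, acquiring the cross-memory-group transition buffers first
// and releasing them once all tasks have executed.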
void call_all_tasks(ExecutionWorkload &workload)
{
    ARM_COMPUTE_ERROR_ON(workload.ctx == nullptr);

    // Acquire memory for the transition buffers
    for(auto &mm_ctx : workload.ctx->memory_managers())
    {
        if(mm_ctx.second.cross_group != nullptr)
        {
            mm_ctx.second.cross_group->acquire();
        }
    }

    // Execute tasks
    for(auto &task : workload.tasks)
    {
        task();
    }

    // Release memory for the transition buffers
    for(auto &mm_ctx : workload.ctx->memory_managers())
    {
        if(mm_ctx.second.cross_group != nullptr)
        {
            mm_ctx.second.cross_group->release();
        }
    }
}

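// Calls the accessor of every workload output tensor and synchronizes the backends;
// returns false if any accessor reports failure.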
bool call_all_output_node_accessors(ExecutionWorkload &workload)
{
    bool is_valid = true;
    std::for_each(std::begin(workload.outputs), std::end(workload.outputs), [&](Tensor * output_tensor)
    {
        bool valid_output = (output_tensor != nullptr) && output_tensor->call_accessor();
        is_valid          = is_valid && valid_output;
    });

    sync_backends();

    return is_valid;
}
} // namespace detail
} // namespace graph
} // namespace arm_compute