xref: /aosp_15_r20/external/ComputeLibrary/src/graph/algorithms/TopologicalSort.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph/algorithms/TopologicalSort.h"
25 
26 #include "arm_compute/graph/Graph.h"
27 
28 #include "support/Iterable.h"
29 
30 #include <list>
31 #include <stack>
32 
33 namespace arm_compute
34 {
35 namespace graph
36 {
37 namespace detail
38 {
39 /** Checks if all the input dependencies of a node have been visited
40  *
41  * @param[in] node    Node to check
42  * @param[in] visited Vector that contains the visited information
43  *
44  * @return True if all inputs dependencies have been visited else false
45  */
all_inputs_are_visited(const INode * node,const std::vector<bool> & visited)46 inline bool all_inputs_are_visited(const INode *node, const std::vector<bool> &visited)
47 {
48     ARM_COMPUTE_ERROR_ON(node == nullptr);
49     const Graph *graph = node->graph();
50     ARM_COMPUTE_ERROR_ON(graph == nullptr);
51 
52     bool are_all_visited = true;
53     for(const auto &input_edge_id : node->input_edges())
54     {
55         if(input_edge_id != EmptyNodeID)
56         {
57             const Edge *input_edge = graph->edge(input_edge_id);
58             ARM_COMPUTE_ERROR_ON(input_edge == nullptr);
59             ARM_COMPUTE_ERROR_ON(input_edge->producer() == nullptr);
60             if(!visited[input_edge->producer_id()])
61             {
62                 are_all_visited = false;
63                 break;
64             }
65         }
66     }
67 
68     return are_all_visited;
69 }
70 } // namespace detail
71 
bfs(Graph & g)72 std::vector<NodeID> bfs(Graph &g)
73 {
74     std::vector<NodeID> bfs_order_vector;
75 
76     // Created visited vector
77     std::vector<bool> visited(g.nodes().size(), false);
78 
79     // Create BFS queue
80     std::list<NodeID> queue;
81 
82     // Push inputs and mark as visited
83     for(auto &input : g.nodes(NodeType::Input))
84     {
85         if(input != EmptyNodeID)
86         {
87             visited[input] = true;
88             queue.push_back(input);
89         }
90     }
91 
92     // Push const nodes and mark as visited
93     for(auto &const_node : g.nodes(NodeType::Const))
94     {
95         if(const_node != EmptyNodeID)
96         {
97             visited[const_node] = true;
98             queue.push_back(const_node);
99         }
100     }
101 
102     // Iterate over vector and edges
103     while(!queue.empty())
104     {
105         // Dequeue a node from queue and process
106         NodeID n = queue.front();
107         bfs_order_vector.push_back(n);
108         queue.pop_front();
109 
110         const INode *node = g.node(n);
111         ARM_COMPUTE_ERROR_ON(node == nullptr);
112         for(const auto &eid : node->output_edges())
113         {
114             const Edge *e = g.edge(eid);
115             ARM_COMPUTE_ERROR_ON(e == nullptr);
116             if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
117             {
118                 visited[e->consumer_id()] = true;
119                 queue.push_back(e->consumer_id());
120             }
121         }
122     }
123 
124     return bfs_order_vector;
125 }
126 
dfs(Graph & g)127 std::vector<NodeID> dfs(Graph &g)
128 {
129     std::vector<NodeID> dfs_order_vector;
130 
131     // Created visited vector
132     std::vector<bool> visited(g.nodes().size(), false);
133 
134     // Create DFS stack
135     std::stack<NodeID> stack;
136 
137     // Push inputs and mark as visited
138     for(auto &input : g.nodes(NodeType::Input))
139     {
140         if(input != EmptyNodeID)
141         {
142             visited[input] = true;
143             stack.push(input);
144         }
145     }
146 
147     // Push const nodes and mark as visited
148     for(auto &const_node : g.nodes(NodeType::Const))
149     {
150         if(const_node != EmptyNodeID)
151         {
152             visited[const_node] = true;
153             stack.push(const_node);
154         }
155     }
156 
157     // Iterate over vector and edges
158     while(!stack.empty())
159     {
160         // Pop a node from stack and process
161         NodeID n = stack.top();
162         dfs_order_vector.push_back(n);
163         stack.pop();
164 
165         // Mark node as visited
166         if(!visited[n])
167         {
168             visited[n] = true;
169         }
170 
171         const INode *node = g.node(n);
172         ARM_COMPUTE_ERROR_ON(node == nullptr);
173         // Reverse iterate to push branches from right to left and pop on the opposite order
174         for(const auto &eid : arm_compute::utils::iterable::reverse_iterate(node->output_edges()))
175         {
176             const Edge *e = g.edge(eid);
177             ARM_COMPUTE_ERROR_ON(e == nullptr);
178             if(!visited[e->consumer_id()] && detail::all_inputs_are_visited(e->consumer(), visited))
179             {
180                 stack.push(e->consumer_id());
181             }
182         }
183     }
184 
185     return dfs_order_vector;
186 }
187 } // namespace graph
188 } // namespace arm_compute
189