xref: /aosp_15_r20/external/armnn/src/armnn/Optimizer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "Optimizer.hpp"
6 #include "Observable.hpp"
7 #include "optimizations/All.hpp"
8 
9 namespace armnn
10 {
11 
Optimizer()12 Optimizer::Optimizer()
13 {
14 }
15 
Pass(Graph & graph,const Optimizations & optimizations)16 void Optimizer::Pass(Graph& graph, const Optimizations& optimizations)
17 {
18     ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "Optimizer_Pass");
19     // Create observables to observe changes to the graph
20     AddedLayerObservable addedLayerObservable(graph);
21     ErasedLayerNamesObservable erasedLayerNamesObservable(graph);
22 
23     bool graphNeedsSorting = false;
24     auto it = graph.TopologicalSort().end();
25 
26     // Calls TopologicalSort() for every iteration to re-order the list in case layers were added/removed.
27     while (it != graph.TopologicalSort().begin())
28     {
29         --it;
30         for (auto&& optimization : optimizations)
31         {
32             ARMNN_ASSERT(*it);
33             optimization->Run(graph, **it);
34 
35             if ((*it)->IsOutputUnconnected())
36             {
37                 auto next = std::next(graph.GetPosInGraph(**it));
38                 graph.EraseLayer(it);
39                 it = next;
40                 graphNeedsSorting = true;
41             }
42 
43             // Add the names of erased layers as related layers to the new added layers
44             for (auto& erasedLayerName : erasedLayerNamesObservable)
45             {
46                 for (auto& addedLayer : addedLayerObservable)
47                 {
48                     addedLayer->AddRelatedLayerName(erasedLayerName);
49                 }
50             }
51 
52             erasedLayerNamesObservable.Clear();
53             addedLayerObservable.Clear();
54 
55             if (graphNeedsSorting)
56             {
57                 graphNeedsSorting = false;
58                 break;
59             }
60         }
61     }
62 }
63 
64 } // namespace armnn
65