xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/Optimization.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Graph.hpp"
8 #include "LayersFwd.hpp"
9 
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 namespace armnn
13 {
14 
15 class Optimization
16 {
17 public:
18     Optimization() = default;
19     virtual ~Optimization() = default;
20     virtual void Run(Graph& graph, Layer& base) const = 0;
21 protected:
22 };
23 
24 // Wrappers
// The implementation of the following wrappers makes use of the CRTP C++ idiom
// (curiously recurring template pattern).
// For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern
28 
29 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType.
30 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
31 ///   after applying each optimization.
32 template <typename BaseType, typename Wrapped>
33 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped
34 {
35 public:
36     using Wrapped::Wrapped;
37 
Run(Graph & graph,Layer & base) const38     void Run(Graph& graph, Layer& base) const override
39     {
40         if (base.GetType() == LayerEnumOf<BaseType>())
41         {
42             Wrapped::Run(graph, *PolymorphicDowncast<BaseType*>(&base));
43         }
44     }
45 
46 protected:
47     ~OptimizeForTypeImpl() = default;
48 };
49 
50 /// Specialization that calls Wrapped::Run() for any layer type.
51 template <typename Wrapped>
52 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped
53 {
54 public:
55     using Wrapped::Wrapped;
56 
Run(Graph & graph,Layer & base) const57     void Run(Graph& graph, Layer& base) const override
58     {
59         Wrapped::Run(graph, base);
60     }
61 
62 protected:
63     ~OptimizeForTypeImpl() = default;
64 };
65 
66 template <typename BaseType, typename Wrapped>
67 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped>
68 {
69 public:
70     using OptimizeForTypeImpl<BaseType, Wrapped>::OptimizeForTypeImpl;
71 };
72 
73 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
74 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
75 ///   after applying each optimization.
76 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
77 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
78 template <typename BaseType, typename ChildType, typename Wrapped>
79 class OptimizeForConnectionImpl : public Wrapped
80 {
81 public:
82     using Wrapped::Wrapped;
83 
Run(Graph & graph,BaseType & base) const84     void Run(Graph& graph, BaseType& base) const
85     {
86         for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
87         {
88             for (auto&& childInput : output->GetConnections())
89             {
90                 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
91                 {
92                     Wrapped::Run(graph, *childInput);
93                 }
94             }
95 
96             // Removes unconnected children.
97             for (unsigned int i = 0; i < output->GetNumConnections();)
98             {
99                 Layer* child = &output->GetConnection(i)->GetOwningLayer();
100 
101                 if (child->IsOutputUnconnected())
102                 {
103                     graph.EraseLayer(child);
104                 }
105                 else
106                 {
107                     ++i;
108                 }
109             }
110         }
111     }
112 
113 protected:
114     ~OptimizeForConnectionImpl() = default;
115 };
116 
117 template <typename BaseType, typename ChildType, typename Wrapped>
118 class OptimizeForConnection final
119     : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>
120 {
121 public:
122     using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
123 };
124 
125 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType.
126 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected
127 ///   after applying each optimization.
128 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones.
129 /// - Children layers are removed if left unconnected after applying the wrapped optimization.
130 template <typename BaseType, typename ChildType, typename Wrapped>
131 class OptimizeForExclusiveConnectionImpl : public Wrapped
132 {
133 public:
134     using Wrapped::Wrapped;
135 
Run(Graph & graph,BaseType & base) const136     void Run(Graph& graph, BaseType& base) const
137     {
138         for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
139         {
140             if (output->GetNumConnections() == 1)
141             {
142                 for (auto&& childInput : output->GetConnections())
143                 {
144                     if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>())
145                     {
146                         Wrapped::Run(graph, *childInput);
147                     }
148                 }
149 
150                 // Removes unconnected children.
151                 for (unsigned int i = 0; i < output->GetNumConnections();)
152                 {
153                     Layer* child = &output->GetConnection(i)->GetOwningLayer();
154 
155                     if (child->IsOutputUnconnected())
156                     {
157                         graph.EraseLayer(child);
158                     }
159                     else
160                     {
161                         ++i;
162                     }
163                 }
164             }
165         }
166     }
167 
168 protected:
169     ~OptimizeForExclusiveConnectionImpl() = default;
170 };
171 
172 template <typename BaseType, typename ChildType, typename Wrapped>
173 class OptimizeForExclusiveConnection final
174         : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>
175 {
176 public:
177     using OptimizeForTypeImpl<BaseType,
178                               OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl;
179 };
180 
181 } // namespace armnn
182