1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Graph.hpp" 8 #include "LayersFwd.hpp" 9 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 namespace armnn 13 { 14 15 class Optimization 16 { 17 public: 18 Optimization() = default; 19 virtual ~Optimization() = default; 20 virtual void Run(Graph& graph, Layer& base) const = 0; 21 protected: 22 }; 23 24 // Wrappers 25 // The implementation of the following wrappers make use of the CRTP C++ idiom 26 // (curiously recurring template pattern). 27 // For details, see https://en.wikipedia.org/wiki/Curiously_recurring_template_pattern 28 29 /// Wrapper Optimization base class that calls Wrapped::Run() for every layer of type BaseType. 30 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected 31 /// after applying each optimization. 32 template <typename BaseType, typename Wrapped> 33 class OptimizeForTypeImpl : public armnn::Optimization, public Wrapped 34 { 35 public: 36 using Wrapped::Wrapped; 37 Run(Graph & graph,Layer & base) const38 void Run(Graph& graph, Layer& base) const override 39 { 40 if (base.GetType() == LayerEnumOf<BaseType>()) 41 { 42 Wrapped::Run(graph, *PolymorphicDowncast<BaseType*>(&base)); 43 } 44 } 45 46 protected: 47 ~OptimizeForTypeImpl() = default; 48 }; 49 50 /// Specialization that calls Wrapped::Run() for any layer type. 51 template <typename Wrapped> 52 class OptimizeForTypeImpl<Layer, Wrapped> : public armnn::Optimization, public Wrapped 53 { 54 public: 55 using Wrapped::Wrapped; 56 Run(Graph & graph,Layer & base) const57 void Run(Graph& graph, Layer& base) const override 58 { 59 Wrapped::Run(graph, base); 60 } 61 62 protected: 63 ~OptimizeForTypeImpl() = default; 64 }; 65 66 template <typename BaseType, typename Wrapped> 67 class OptimizeForType final : public OptimizeForTypeImpl<BaseType, Wrapped> 68 { 69 public: 70 using OptimizeForTypeImpl<BaseType, Wrapped>::OptimizeForTypeImpl; 71 }; 72 73 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. 74 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected 75 /// after applying each optimization. 76 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. 77 /// - Children layers are removed if left unconnected after applying the wrapped optimization. 78 template <typename BaseType, typename ChildType, typename Wrapped> 79 class OptimizeForConnectionImpl : public Wrapped 80 { 81 public: 82 using Wrapped::Wrapped; 83 Run(Graph & graph,BaseType & base) const84 void Run(Graph& graph, BaseType& base) const 85 { 86 for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) 87 { 88 for (auto&& childInput : output->GetConnections()) 89 { 90 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>()) 91 { 92 Wrapped::Run(graph, *childInput); 93 } 94 } 95 96 // Removes unconnected children. 97 for (unsigned int i = 0; i < output->GetNumConnections();) 98 { 99 Layer* child = &output->GetConnection(i)->GetOwningLayer(); 100 101 if (child->IsOutputUnconnected()) 102 { 103 graph.EraseLayer(child); 104 } 105 else 106 { 107 ++i; 108 } 109 } 110 } 111 } 112 113 protected: 114 ~OptimizeForConnectionImpl() = default; 115 }; 116 117 template <typename BaseType, typename ChildType, typename Wrapped> 118 class OptimizeForConnection final 119 : public OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>> 120 { 121 public: 122 using OptimizeForTypeImpl<BaseType, OptimizeForConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; 123 }; 124 125 /// Wrapper Optimization class that calls Wrapped::Run for every connection BaseType -> ChildType. 126 /// - Wrapped class mustn't remove the base layer. The optimizer will remove it if left unconnected 127 /// after applying each optimization. 128 /// - Wrapped class mustn't affect existing connections in the same output. It might add new ones. 129 /// - Children layers are removed if left unconnected after applying the wrapped optimization. 130 template <typename BaseType, typename ChildType, typename Wrapped> 131 class OptimizeForExclusiveConnectionImpl : public Wrapped 132 { 133 public: 134 using Wrapped::Wrapped; 135 Run(Graph & graph,BaseType & base) const136 void Run(Graph& graph, BaseType& base) const 137 { 138 for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output) 139 { 140 if (output->GetNumConnections() == 1) 141 { 142 for (auto&& childInput : output->GetConnections()) 143 { 144 if (childInput->GetOwningLayer().GetType() == LayerEnumOf<ChildType>()) 145 { 146 Wrapped::Run(graph, *childInput); 147 } 148 } 149 150 // Removes unconnected children. 151 for (unsigned int i = 0; i < output->GetNumConnections();) 152 { 153 Layer* child = &output->GetConnection(i)->GetOwningLayer(); 154 155 if (child->IsOutputUnconnected()) 156 { 157 graph.EraseLayer(child); 158 } 159 else 160 { 161 ++i; 162 } 163 } 164 } 165 } 166 } 167 168 protected: 169 ~OptimizeForExclusiveConnectionImpl() = default; 170 }; 171 172 template <typename BaseType, typename ChildType, typename Wrapped> 173 class OptimizeForExclusiveConnection final 174 : public OptimizeForTypeImpl<BaseType, OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>> 175 { 176 public: 177 using OptimizeForTypeImpl<BaseType, 178 OptimizeForExclusiveConnectionImpl<BaseType, ChildType, Wrapped>>::OptimizeForTypeImpl; 179 }; 180 181 } // namespace armnn 182