1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 #include <armnn/utility/IgnoreUnused.hpp> 10 #include <armnn/utility/PolymorphicDowncast.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 template <typename PermuteType> 18 class OptimizeInversePermutesImpl 19 { 20 public: 21 /// Run for every connection between a base PermuteLayer and a child PermuteLayer. 22 /// Bypasses both layers for that connection if one is the inverse of the other. Run(Graph & graph,InputSlot & connection) const23 void Run(Graph& graph, InputSlot& connection) const 24 { 25 IgnoreUnused(graph); 26 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 27 auto child = PolymorphicDowncast<PermuteType*>(&connection.GetOwningLayer()); 28 29 if (child->IsInverse(*PolymorphicDowncast<PermuteType*>(&base))) 30 { 31 // Bypass both layers. Child will be removed as it's left unconnected. 32 // Base layer will be removed if left unconnected. 33 child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot()); 34 } 35 } 36 37 protected: 38 OptimizeInversePermutesImpl() = default; 39 ~OptimizeInversePermutesImpl() = default; 40 }; 41 42 using OptimizeInversePermutes = OptimizeForConnection<PermuteLayer, PermuteLayer, 43 OptimizeInversePermutesImpl<PermuteLayer>>; 44 using OptimizeInverseTransposes = OptimizeForConnection<TransposeLayer, TransposeLayer, 45 OptimizeInversePermutesImpl<TransposeLayer>>; 46 47 } // namespace optimizations 48 } // namespace armnn 49