xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/OptimizeInversePermutes.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 #include <armnn/utility/IgnoreUnused.hpp>
10 #include <armnn/utility/PolymorphicDowncast.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 template <typename PermuteType>
18 class OptimizeInversePermutesImpl
19 {
20 public:
21     /// Run for every connection between a base PermuteLayer and a child PermuteLayer.
22     /// Bypasses both layers for that connection if one is the inverse of the other.
Run(Graph & graph,InputSlot & connection) const23     void Run(Graph& graph, InputSlot& connection) const
24     {
25         IgnoreUnused(graph);
26         Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
27         auto child = PolymorphicDowncast<PermuteType*>(&connection.GetOwningLayer());
28 
29         if (child->IsInverse(*PolymorphicDowncast<PermuteType*>(&base)))
30         {
31             // Bypass both layers. Child will be removed as it's left unconnected.
32             // Base layer will be removed if left unconnected.
33             child->GetOutputSlot().MoveAllConnections(*base.GetInputSlot(0).GetConnectedOutputSlot());
34         }
35     }
36 
37 protected:
38     OptimizeInversePermutesImpl() = default;
39     ~OptimizeInversePermutesImpl() = default;
40 };
41 
42 using OptimizeInversePermutes = OptimizeForConnection<PermuteLayer, PermuteLayer,
43     OptimizeInversePermutesImpl<PermuteLayer>>;
44 using OptimizeInverseTransposes = OptimizeForConnection<TransposeLayer, TransposeLayer,
45     OptimizeInversePermutesImpl<TransposeLayer>>;
46 
47 } // namespace optimizations
48 } // namespace armnn
49