xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/PermuteAsReshape.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
14 class PermuteAsReshapeImpl
15 {
16 public:
17     /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent.
Run(Graph & graph,PermuteLayer & permute) const18     void Run(Graph& graph, PermuteLayer& permute) const
19     {
20         if (IsReshape(permute))
21         {
22             const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo();
23 
24             const std::string name = std::string("as_reshape-") + permute.GetName();
25             const ReshapeDescriptor descriptor{outInfo.GetShape()};
26             // Inserts NewLayer so layers don't need to be re-sorted.
27             auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str());
28 
29             // Bypass permute. It will be deleted since it's left unconnected.
30             permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31         }
32     }
33 
34 protected:
35     PermuteAsReshapeImpl() = default;
36     ~PermuteAsReshapeImpl() = default;
37 
38 private:
IsReshape(const PermuteLayer & layer)39     static bool IsReshape(const PermuteLayer& layer)
40     {
41         const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42         const PermutationVector& permutation = layer.GetPermutation();
43 
44         const unsigned int numDimensions = permutation.GetSize();
45 
46         unsigned int lastGtOne = 0;
47         while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U))
48         {
49             ++lastGtOne;
50         }
51 
52         bool isReshape = true;
53         for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
54         {
55             if (outShape[permutation[i]] > 1U)
56             {
57                 isReshape = permutation[lastGtOne] < permutation[i];
58                 lastGtOne = i;
59             }
60         }
61 
62         return isReshape;
63     }
64 };
65 
66 using PermuteAsReshape = OptimizeForType<PermuteLayer, PermuteAsReshapeImpl>;
67 
68 } // namespace optimizations
69 } // namespace armnn
70