1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 namespace armnn 10 { 11 namespace optimizations 12 { 13 14 class PermuteAsReshapeImpl 15 { 16 public: 17 /// Run for every PermuteLayer. Replaces it with a ReshapeLayer if they are equivalent. Run(Graph & graph,PermuteLayer & permute) const18 void Run(Graph& graph, PermuteLayer& permute) const 19 { 20 if (IsReshape(permute)) 21 { 22 const TensorInfo& outInfo = permute.GetOutputHandler().GetTensorInfo(); 23 24 const std::string name = std::string("as_reshape-") + permute.GetName(); 25 const ReshapeDescriptor descriptor{outInfo.GetShape()}; 26 // Inserts NewLayer so layers don't need to be re-sorted. 27 auto reshape = graph.InsertNewLayer<ReshapeLayer>(permute.GetInputSlot(0), descriptor, name.c_str()); 28 29 // Bypass permute. It will be deleted since it's left unconnected. 30 permute.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); 31 } 32 } 33 34 protected: 35 PermuteAsReshapeImpl() = default; 36 ~PermuteAsReshapeImpl() = default; 37 38 private: IsReshape(const PermuteLayer & layer)39 static bool IsReshape(const PermuteLayer& layer) 40 { 41 const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); 42 const PermutationVector& permutation = layer.GetPermutation(); 43 44 const unsigned int numDimensions = permutation.GetSize(); 45 46 unsigned int lastGtOne = 0; 47 while ((lastGtOne < numDimensions) && (outShape[(permutation[lastGtOne])] == 1U)) 48 { 49 ++lastGtOne; 50 } 51 52 bool isReshape = true; 53 for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) 54 { 55 if (outShape[permutation[i]] > 1U) 56 { 57 isReshape = permutation[lastGtOne] < permutation[i]; 58 lastGtOne = i; 59 } 60 } 61 62 return isReshape; 63 } 64 }; 65 66 using PermuteAsReshape = OptimizeForType<PermuteLayer, PermuteAsReshapeImpl>; 67 68 } // namespace optimizations 69 } // namespace armnn 70