1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "Optimization.hpp" 9 #include <armnnUtils/Permute.hpp> 10 #include <ResolveType.hpp> 11 12 namespace armnn 13 { 14 namespace optimizations 15 { 16 17 class ConvertConstPermuteLayersToConstLayers 18 { 19 public: Run(Graph & graph,InputSlot & connection) const20 void Run(Graph& graph, InputSlot& connection) const 21 { 22 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 23 Layer& child = connection.GetOwningLayer(); 24 25 ARMNN_ASSERT(base.GetType() == LayerType::Constant); 26 ARMNN_ASSERT(child.GetType() == LayerType::Permute); 27 28 if (base.GetDataType() == child.GetDataType()) 29 { 30 switch (base.GetDataType()) 31 { 32 case DataType::Float16: 33 ReplaceConstPermuteLayer<DataType::Float16>(graph, 34 PolymorphicDowncast<ConstantLayer*>(&base), 35 PolymorphicDowncast<PermuteLayer*>(&child)); 36 break; 37 case DataType::Float32: 38 ReplaceConstPermuteLayer<DataType::Float32>(graph, 39 PolymorphicDowncast<ConstantLayer*>(&base), 40 PolymorphicDowncast<PermuteLayer*>(&child)); 41 break; 42 case DataType::QAsymmU8: 43 ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph, 44 PolymorphicDowncast<ConstantLayer*>(&base), 45 PolymorphicDowncast<PermuteLayer*>(&child)); 46 break; 47 case DataType::Signed32: 48 ReplaceConstPermuteLayer<DataType::Signed32>(graph, 49 PolymorphicDowncast<ConstantLayer*>(&base), 50 PolymorphicDowncast<PermuteLayer*>(&child)); 51 break; 52 case DataType::QSymmS16: 53 ReplaceConstPermuteLayer<DataType::QSymmS16>(graph, 54 PolymorphicDowncast<ConstantLayer*>(&base), 55 PolymorphicDowncast<PermuteLayer*>(&child)); 56 break; 57 case DataType::QSymmS8: 58 ReplaceConstPermuteLayer<DataType::QSymmS8>(graph, 59 PolymorphicDowncast<ConstantLayer*>(&base), 60 PolymorphicDowncast<PermuteLayer*>(&child)); 61 break; 62 case DataType::QAsymmS8: 63 ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph, 64 PolymorphicDowncast<ConstantLayer*>(&base), 65 PolymorphicDowncast<PermuteLayer*>(&child)); 66 break; 67 case DataType::BFloat16: 68 ReplaceConstPermuteLayer<DataType::BFloat16>(graph, 69 PolymorphicDowncast<ConstantLayer*>(&base), 70 PolymorphicDowncast<PermuteLayer*>(&child)); 71 break; 72 case DataType::Signed64: 73 ReplaceConstPermuteLayer<DataType::Signed64>(graph, 74 PolymorphicDowncast<ConstantLayer*>(&base), 75 PolymorphicDowncast<PermuteLayer*>(&child)); 76 break; 77 case DataType::Boolean: 78 ReplaceConstPermuteLayer<DataType::Boolean>(graph, 79 PolymorphicDowncast<ConstantLayer*>(&base), 80 PolymorphicDowncast<PermuteLayer*>(&child)); 81 break; 82 } 83 } 84 } 85 protected: 86 ConvertConstPermuteLayersToConstLayers() = default; 87 ~ConvertConstPermuteLayersToConstLayers() = default; 88 private: 89 template<armnn::DataType ArmnnType, 90 typename T = armnn::ResolveType<ArmnnType>> ReplaceConstPermuteLayer(Graph & graph,ConstantLayer * constantLayer,PermuteLayer * permuteLayer)91 static void ReplaceConstPermuteLayer(Graph& graph, 92 ConstantLayer* constantLayer, 93 PermuteLayer* permuteLayer) 94 { 95 IgnoreUnused(graph); 96 /** 97 * This optimisation is to find situations where a constant set of inputs is being provided to a Permute 98 * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we 99 * want to Permute them once and store them in a Const layer to be used everytime as they will not change. 100 */ 101 TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo(); 102 std::vector<T> newValues(outputPermuteInfo.GetNumElements()); 103 armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(), 104 constantLayer->m_LayerOutput->Map(true), newValues.data(), 105 GetDataTypeSize(outputPermuteInfo.GetDataType())); 106 107 TensorInfo newInfo = outputPermuteInfo; 108 newInfo.SetConstant(true); 109 ConstTensor newInput(newInfo, newValues); 110 constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput)); 111 112 // Moves connections in permute output to the constant layer. 113 // Permute layer will be removed if left unconnected. 114 permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot()); 115 116 // Updating the output tensor 117 constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo); 118 ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true); 119 } 120 }; 121 122 using FusePermuteIntoConstLayer = OptimizeForConnection<ConstantLayer, 123 PermuteLayer, 124 ConvertConstPermuteLayersToConstLayers>; 125 126 } // namespace optimizations 127 } // namespace armnn