1 // 2 // Copyright © 2019 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 namespace armnn 10 { 11 namespace optimizations 12 { 13 14 /// Replaces Permute leading into BatchToSpace with a DepthToSpace 15 /// in the case where the Permute swaps the batch and channels dimensions 16 /// such that the replacement is valid. 17 template <typename PermuteType> 18 class PermuteAndBatchToSpaceAsDepthToSpaceImpl 19 { 20 public: Run(Graph & graph,InputSlot & connection) const21 void Run(Graph& graph, InputSlot& connection) const 22 { 23 // Validate base layer (the Permute) is compatible 24 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 25 ARMNN_ASSERT(base.GetType() == LayerType::Permute || base.GetType() == LayerType::Transpose); 26 const TensorInfo& inputInfo = base.GetInputSlot(0).GetConnection()->GetTensorInfo(); 27 const TensorInfo& intermediateInfo = base.GetOutputSlot(0).GetTensorInfo(); 28 if (intermediateInfo.GetNumDimensions() != 4) 29 { 30 // Must be 4D, otherwise the below checks do not make sense 31 return; 32 } 33 if (!static_cast<PermuteType&>(base).GetParameters().m_DimMappings.IsEqual(PermutationVector{ 3, 1, 2, 0 })) 34 { 35 // Must swap batch and channels dimensions, otherwise it is not the (original) channels dimension 36 // that is being decomposed. 37 return; 38 } 39 40 // Validate child layer (the BatchToSpace) is compatible 41 Layer& child = connection.GetOwningLayer(); 42 ARMNN_ASSERT(child.GetType() == LayerType::BatchToSpaceNd); 43 const TensorInfo& outputInfo = child.GetOutputSlot(0).GetTensorInfo(); 44 const BatchToSpaceNdDescriptor& batchToSpaceDesc = static_cast<BatchToSpaceNdLayer&>(child).GetParameters(); 45 if (batchToSpaceDesc.m_DataLayout != DataLayout::NHWC) 46 { 47 // The rest of this function assumes NHWC, although in future this restriction could be lifted. 48 return; 49 } 50 if (batchToSpaceDesc.m_Crops != std::vector<std::pair<unsigned int, unsigned int>>{ { 0, 0 }, { 0, 0 } }) 51 { 52 // Cropping is not supported in DepthToSpace 53 return; 54 } 55 if (batchToSpaceDesc.m_BlockShape.size() != 2 || 56 batchToSpaceDesc.m_BlockShape[0] != batchToSpaceDesc.m_BlockShape[1]) 57 { 58 // Asymmetric or non-2D block sizes are not supported by DepthToSpace 59 return; 60 } 61 uint32_t blockSize = batchToSpaceDesc.m_BlockShape[0]; 62 if (outputInfo.GetShape()[0] != 1 || outputInfo.GetShape()[3] != 1) 63 { 64 // The final output must have 1 batch and 1 channel because these dimensions will be swapped around 65 // once we make the substitution, and it needs to be equivalent. 66 return; 67 } 68 69 // Validate the intermediate tensor quantization params. 70 // These must be identical to either the input or output quantization params, otherwise the intermediate tensor 71 // may not have sufficient range/precision to preserve the values. 72 // This would mean that once we perform the substitution this loss of precision will no longer occur, 73 // so we would have changed the meaning of the network. 74 bool isIntermediateQuantParamsSameAsInput = 75 intermediateInfo.GetQuantizationScale() == inputInfo.GetQuantizationScale() && 76 intermediateInfo.GetQuantizationOffset() == inputInfo.GetQuantizationOffset(); 77 bool isIntermediateQuantParamsSameAsOutput = 78 intermediateInfo.GetQuantizationScale() == outputInfo.GetQuantizationScale() && 79 intermediateInfo.GetQuantizationOffset() == outputInfo.GetQuantizationOffset(); 80 if (!isIntermediateQuantParamsSameAsInput && !isIntermediateQuantParamsSameAsOutput) 81 { 82 return; 83 } 84 85 // Insert equivalent DepthToSpace layer 86 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); 87 88 // Inserts equivalent reshape before base layer. 89 const DepthToSpaceDescriptor depthToSpaceDesc(blockSize, DataLayout::NHWC); 90 auto& depthToSpace = *graph.InsertNewLayer<DepthToSpaceLayer>(base.GetInputSlot(0), 91 depthToSpaceDesc, 92 name.c_str()); 93 94 // Moves connections from child output to new layer. 95 // Child layer will be removed as it's left unconnected. 96 // Base layer will be removed if left unconnected. 97 child.GetOutputSlot().MoveAllConnections(depthToSpace.GetOutputSlot()); 98 } 99 }; 100 101 using PermuteAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<PermuteLayer, BatchToSpaceNdLayer, 102 PermuteAndBatchToSpaceAsDepthToSpaceImpl<PermuteLayer>>; 103 using TransposeAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<TransposeLayer, BatchToSpaceNdLayer, 104 PermuteAndBatchToSpaceAsDepthToSpaceImpl<TransposeLayer>>; 105 } // namespace optimizations 106 } // namespace armnn 107