xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/PermuteAndBatchToSpaceAsDepthToSpace.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2019 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"

namespace armnn
{
namespace optimizations
{

/// Replaces Permute leading into BatchToSpace with a DepthToSpace
/// in the case where the Permute swaps the batch and channels dimensions
/// such that the replacement is valid.
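///
/// As a shape-only illustration (not taken from the implementation below; assumes NHWC layout and a block
/// size of 2, matching the checks in Run()):
///     [1, H, W, 4] -> Permute {3, 1, 2, 0} -> [4, H, W, 1] -> BatchToSpaceNd {2, 2} -> [1, 2H, 2W, 1]
/// which is the same output shape as a single DepthToSpace layer with block size 2 applied to [1, H, W, 4].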
template <typename PermuteType>
class PermuteAndBatchToSpaceAsDepthToSpaceImpl
{
public:
    void Run(Graph& graph, InputSlot& connection) const
    {
        // Validate base layer (the Permute) is compatible
        Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
        ARMNN_ASSERT(base.GetType() == LayerType::Permute || base.GetType() == LayerType::Transpose);
        const TensorInfo& inputInfo = base.GetInputSlot(0).GetConnection()->GetTensorInfo();
        const TensorInfo& intermediateInfo = base.GetOutputSlot(0).GetTensorInfo();
        if (intermediateInfo.GetNumDimensions() != 4)
        {
            // Must be 4D, otherwise the below checks do not make sense
            return;
        }
        if (!static_cast<PermuteType&>(base).GetParameters().m_DimMappings.IsEqual(PermutationVector{ 3, 1, 2, 0 }))
        {
            // Must swap batch and channels dimensions, otherwise it is not the (original) channels dimension
            // that is being decomposed.
            return;
        }

        // Validate child layer (the BatchToSpace) is compatible
        Layer& child = connection.GetOwningLayer();
        ARMNN_ASSERT(child.GetType() == LayerType::BatchToSpaceNd);
        const TensorInfo& outputInfo = child.GetOutputSlot(0).GetTensorInfo();
        const BatchToSpaceNdDescriptor& batchToSpaceDesc = static_cast<BatchToSpaceNdLayer&>(child).GetParameters();
        if (batchToSpaceDesc.m_DataLayout != DataLayout::NHWC)
        {
            // The rest of this function assumes NHWC, although in future this restriction could be lifted.
            return;
        }
        if (batchToSpaceDesc.m_Crops != std::vector<std::pair<unsigned int, unsigned int>>{ { 0, 0 }, { 0, 0 } })
        {
            // Cropping is not supported in DepthToSpace
            return;
        }
        if (batchToSpaceDesc.m_BlockShape.size() != 2 ||
            batchToSpaceDesc.m_BlockShape[0] != batchToSpaceDesc.m_BlockShape[1])
        {
            // Asymmetric or non-2D block sizes are not supported by DepthToSpace
            return;
        }
        uint32_t blockSize = batchToSpaceDesc.m_BlockShape[0];
        if (outputInfo.GetShape()[0] != 1 || outputInfo.GetShape()[3] != 1)
        {
            // The final output must have 1 batch and 1 channel because these dimensions will be swapped around
            // once we make the substitution, and it needs to be equivalent.
            return;
        }

        // Validate the intermediate tensor quantization params.
        // These must be identical to either the input or output quantization params, otherwise the intermediate tensor
        // may not have sufficient range/precision to preserve the values.
        // This would mean that once we perform the substitution this loss of precision will no longer occur,
        // so we would have changed the meaning of the network.
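        // As an illustration (values chosen for this comment only, not from any real network): if the input and
        // output both use scale 0.05 but the intermediate tensor uses scale 0.5, values passing through the
        // intermediate are quantized ten times more coarsely, so removing that step would change the network's
        // results rather than just its structure.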
        bool isIntermediateQuantParamsSameAsInput =
                intermediateInfo.GetQuantizationScale() == inputInfo.GetQuantizationScale() &&
                intermediateInfo.GetQuantizationOffset() == inputInfo.GetQuantizationOffset();
        bool isIntermediateQuantParamsSameAsOutput =
                intermediateInfo.GetQuantizationScale() == outputInfo.GetQuantizationScale() &&
                intermediateInfo.GetQuantizationOffset() == outputInfo.GetQuantizationOffset();
        if (!isIntermediateQuantParamsSameAsInput && !isIntermediateQuantParamsSameAsOutput)
        {
            return;
        }

        // Insert equivalent DepthToSpace layer
        const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();

        // Inserts the equivalent DepthToSpace layer before the base layer.
        const DepthToSpaceDescriptor depthToSpaceDesc(blockSize, DataLayout::NHWC);
        auto& depthToSpace = *graph.InsertNewLayer<DepthToSpaceLayer>(base.GetInputSlot(0),
                                                                      depthToSpaceDesc,
                                                                      name.c_str());

        // Moves connections from child output to new layer.
        // Child layer will be removed as it's left unconnected.
        // Base layer will be removed if left unconnected.
        child.GetOutputSlot().MoveAllConnections(depthToSpace.GetOutputSlot());
    }
};

using PermuteAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<PermuteLayer, BatchToSpaceNdLayer,
    PermuteAndBatchToSpaceAsDepthToSpaceImpl<PermuteLayer>>;
using TransposeAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<TransposeLayer, BatchToSpaceNdLayer,
    PermuteAndBatchToSpaceAsDepthToSpaceImpl<TransposeLayer>>;
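
// A minimal usage sketch (for illustration only; in ArmNN these substitutions are registered by the
// library's own optimization pass rather than by user code). Assuming a Graph named 'graph':
//
//     Optimizer::Pass(graph, MakeOptimizations(PermuteAndBatchToSpaceAsDepthToSpace(),
//                                              TransposeAndBatchToSpaceAsDepthToSpace()));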
}    // namespace optimizations
}    // namespace armnn