xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "Optimization.hpp"
9 #include <armnnUtils/Permute.hpp>
10 #include <ResolveType.hpp>
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 class ConvertConstPermuteLayersToConstLayers
18 {
19 public:
Run(Graph & graph,InputSlot & connection) const20     void Run(Graph& graph, InputSlot& connection) const
21     {
22         Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
23         Layer& child = connection.GetOwningLayer();
24 
25         ARMNN_ASSERT(base.GetType() == LayerType::Constant);
26         ARMNN_ASSERT(child.GetType() == LayerType::Permute);
27 
28         if (base.GetDataType() == child.GetDataType())
29         {
30             switch (base.GetDataType())
31             {
32                 case DataType::Float16:
33                     ReplaceConstPermuteLayer<DataType::Float16>(graph,
34                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
35                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
36                     break;
37                 case DataType::Float32:
38                     ReplaceConstPermuteLayer<DataType::Float32>(graph,
39                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
40                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
41                     break;
42                 case DataType::QAsymmU8:
43                     ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
44                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
45                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
46                     break;
47                 case DataType::Signed32:
48                     ReplaceConstPermuteLayer<DataType::Signed32>(graph,
49                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
50                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
51                     break;
52                 case DataType::QSymmS16:
53                     ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
54                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
55                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
56                     break;
57                 case DataType::QSymmS8:
58                     ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
59                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
60                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
61                     break;
62                 case DataType::QAsymmS8:
63                     ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
64                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
65                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
66                     break;
67                 case DataType::BFloat16:
68                     ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
69                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
70                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
71                     break;
72                 case DataType::Signed64:
73                     ReplaceConstPermuteLayer<DataType::Signed64>(graph,
74                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
75                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
76                     break;
77                 case DataType::Boolean:
78                     ReplaceConstPermuteLayer<DataType::Boolean>(graph,
79                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
80                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
81                     break;
82             }
83         }
84     }
85 protected:
86     ConvertConstPermuteLayersToConstLayers()  = default;
87     ~ConvertConstPermuteLayersToConstLayers() = default;
88 private:
89     template<armnn::DataType ArmnnType,
90              typename T = armnn::ResolveType<ArmnnType>>
ReplaceConstPermuteLayer(Graph & graph,ConstantLayer * constantLayer,PermuteLayer * permuteLayer)91     static void ReplaceConstPermuteLayer(Graph& graph,
92                                          ConstantLayer* constantLayer,
93                                          PermuteLayer* permuteLayer)
94     {
95         IgnoreUnused(graph);
96         /**
97          * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
98          * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
99          * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
100          */
101         TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
102         std::vector<T> newValues(outputPermuteInfo.GetNumElements());
103         armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
104                             constantLayer->m_LayerOutput->Map(true), newValues.data(),
105                             GetDataTypeSize(outputPermuteInfo.GetDataType()));
106 
107         TensorInfo newInfo = outputPermuteInfo;
108         newInfo.SetConstant(true);
109         ConstTensor newInput(newInfo, newValues);
110         constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
111 
112         // Moves connections in permute output to the constant layer.
113         // Permute layer will be removed if left unconnected.
114         permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());
115 
116         // Updating the output tensor
117         constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
118         ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
119     }
120 };
121 
122 using FusePermuteIntoConstLayer = OptimizeForConnection<ConstantLayer,
123                                                         PermuteLayer,
124                                                         ConvertConstPermuteLayersToConstLayers>;
125 
126 } // namespace optimizations
127 } // namespace armnn