// xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"
#include "NetworkUtils.hpp"

10 namespace armnn
11 {
12 namespace optimizations
13 {
14 
15 class ConvertFp32NetworkToFp16Impl
16 {
17 public:
Run(Graph & graph,Layer & layer) const18     void Run(Graph& graph, Layer& layer) const
19     {
20         if(layer.GetType() == LayerType::Input)
21         {
22             // if the outputs of this layer are DataType::Float32
23             // add a ConvertFloat32ToFloat16 layer after each of the outputs
24             if (layer.GetDataType() == DataType::Float32)
25             {
26                 InsertConvertFp32ToFp16LayersAfter(graph, layer);
27             }
28         }
29         else if (layer.GetType() == LayerType::Output)
30         {
31             // For DetectionPostProcess Layer output is always Float32 regardless of input type
32             Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
33             if (connectedLayer.GetType() != LayerType::DetectionPostProcess)
34             {
35                 // if the inputs of this layer are DataType::Float32
36                 // add a ConvertFloat16ToFloat32 layer before each of the inputs
37                 if (layer.GetDataType() == DataType::Float32)
38                 {
39                     // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
40                     // here, otherwise it will expect the inputs to be DataType::Float16
41                     InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
42                 }
43             }
44         }
45         else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
46         {
47             // if the inputs/outputs of this layer are DataType::Float32
48             // change the data type for all inputs and outputs to DataType::Float16
49             for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
50             {
51                 // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection
52                 // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer
53                 Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
54                 if (base.GetType() != LayerType::Input)
55                 {
56                     TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
57                     if (convertInfo.GetDataType() == DataType::Float32)
58                     {
59                         convertInfo.SetDataType(DataType::Float16);
60                         input->GetConnection()->SetTensorInfo(convertInfo);
61                     }
62                 }
63             }
64 
65             // For DetectionPostProcess Layer output is always Float32 regardless of input type
66             if (layer.GetType() != LayerType::DetectionPostProcess)
67             {
68                 // change outputs to DataType::Float16
69                 for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
70                 {
71                     TensorInfo convertInfo = output->GetTensorInfo();
72                     if (convertInfo.GetDataType() == DataType::Float32)
73                     {
74                         convertInfo.SetDataType(DataType::Float16);
75                         output->SetTensorInfo(convertInfo);
76                     }
77                 }
78             }
79         }
80     }
81 
82 protected:
83     ConvertFp32NetworkToFp16Impl() = default;
84     ~ConvertFp32NetworkToFp16Impl() = default;
85 };
86 
87 using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>;
88 
89 } // namespace optimizations
90 } // namespace armnn
91