//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "Optimization.hpp"
#include "NetworkUtils.hpp"

namespace armnn
{
namespace optimizations
{

class ConvertFp32NetworkToFp16Impl
{
public:
    void Run(Graph& graph, Layer& layer) const
    {
        if (layer.GetType() == LayerType::Input)
        {
            // If the outputs of this layer are DataType::Float32,
            // add a ConvertFp32ToFp16 layer after each of the outputs.
            if (layer.GetDataType() == DataType::Float32)
            {
                InsertConvertFp32ToFp16LayersAfter(graph, layer);
            }
        }
        else if (layer.GetType() == LayerType::Output)
        {
            // The output of a DetectionPostProcess layer is always Float32, regardless of its input type,
            // so no conversion layer is needed in that case.
            Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
            if (connectedLayer.GetType() != LayerType::DetectionPostProcess)
            {
                // If the inputs of this layer are DataType::Float32,
                // add a ConvertFp16ToFp32 layer before each of the inputs.
                if (layer.GetDataType() == DataType::Float32)
                {
                    // NOTE: InsertConvertFp16ToFp32LayersBefore must be called with
                    // expectCorrectInputType = false here, otherwise it will expect
                    // the inputs to already be DataType::Float16.
                    InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
                }
            }
        }
        else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
        {
            // If the inputs/outputs of this layer are DataType::Float32,
            // change the data type of all inputs and outputs to DataType::Float16.
            for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
            {
                // If the slot is connected to an OutputSlot of an InputLayer, do not change the DataType
                // of the connection; the InputSlots of the current layer will be updated when the
                // conversion layer is inserted after the InputLayer.
                Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
                if (base.GetType() != LayerType::Input)
                {
                    TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
                    if (convertInfo.GetDataType() == DataType::Float32)
                    {
                        convertInfo.SetDataType(DataType::Float16);
                        input->GetConnection()->SetTensorInfo(convertInfo);
                    }
                }
            }

            // The output of a DetectionPostProcess layer is always Float32, regardless of input type.
            if (layer.GetType() != LayerType::DetectionPostProcess)
            {
                // Change the outputs to DataType::Float16.
                for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
                {
                    TensorInfo convertInfo = output->GetTensorInfo();
                    if (convertInfo.GetDataType() == DataType::Float32)
                    {
                        convertInfo.SetDataType(DataType::Float16);
                        output->SetTensorInfo(convertInfo);
                    }
                }
            }
        }
    }

protected:
    ConvertFp32NetworkToFp16Impl() = default;
    ~ConvertFp32NetworkToFp16Impl() = default;
};

using Fp32NetworkToFp16Converter = OptimizeForType<Layer, ConvertFp32NetworkToFp16Impl>;

} // namespace optimizations
} // namespace armnn
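
// Usage note (a minimal sketch, not part of the original header): Fp32NetworkToFp16Converter
// is a whole-graph substitution, so it is applied through Optimizer::Pass together with
// MakeOptimizations rather than being invoked directly; Arm NN's Optimize() pipeline applies
// it this way when the optimizer options request FP32-to-FP16 reduction. The optGraph
// variable below is illustrative; the conversion-layer insertion afterwards assumes the
// surrounding pipeline still runs its usual follow-up passes.
//
//     using namespace armnn;
//     using namespace armnn::optimizations;
//
//     Graph& optGraph = ...; // graph of the network being optimized
//     Optimizer::Pass(optGraph, MakeOptimizations(Fp32NetworkToFp16Converter()));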