1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include "Optimization.hpp" 9 10 #include <armnnUtils/FloatingPointConverter.hpp> 11 #include <armnn/backends/TensorHandle.hpp> 12 #include <armnn/utility/IgnoreUnused.hpp> 13 14 #include <Half.hpp> 15 16 namespace armnn 17 { 18 namespace optimizations 19 { 20 21 struct Float16ToFloat32 22 { Funcarmnn::optimizations::Float16ToFloat3223 static void Func(std::shared_ptr<ConstTensorHandle>& handle) 24 { 25 const TensorInfo& info = handle->GetTensorInfo(); 26 27 if (info.GetDataType() == DataType::Float16) 28 { 29 std::vector<float> newValues(info.GetNumElements()); 30 31 armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetConstTensor<Half>(), 32 info.GetNumElements(), 33 newValues.data()); 34 35 TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true); 36 ConstTensor newInput(newInfo, newValues); 37 handle.reset(new ScopedTensorHandle(newInput)); 38 } 39 } 40 }; 41 42 struct Float32ToFloat16 43 { Funcarmnn::optimizations::Float32ToFloat1644 static void Func(std::shared_ptr<ConstTensorHandle>& handle) 45 { 46 const TensorInfo& info = handle->GetTensorInfo(); 47 48 if (info.GetDataType() == DataType::Float32) 49 { 50 std::vector<Half> newValues(info.GetNumElements()); 51 52 armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(), 53 info.GetNumElements(), 54 newValues.data()); 55 56 TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true); 57 ConstTensor newInput(newInfo, newValues); 58 handle.reset(new ScopedTensorHandle(newInput)); 59 } 60 } 61 }; 62 63 template<typename Converter, typename Predicate> 64 class ConvertConstants : public Optimization 65 { 66 public: 67 ConvertConstants() = default; 68 ConvertConstants(const ConvertConstants&) = default; 69 virtual ~ConvertConstants() = default; 70 Run(Graph & graph,Layer & layer) const71 void Run(Graph& graph, Layer& layer) const override 72 { 73 IgnoreUnused(graph); 74 if (Predicate::Test(layer)) 75 { 76 layer.OperateOnConstantTensors(Converter::Func); 77 } 78 } 79 protected: 80 }; 81 82 struct IsFloat32Layer 83 { Testarmnn::optimizations::IsFloat32Layer84 static bool Test(const Layer& layer) 85 { 86 return layer.GetDataType() == DataType::Float32; 87 } 88 }; 89 90 struct IsFloat16Layer 91 { Testarmnn::optimizations::IsFloat16Layer92 static bool Test(const Layer& layer) 93 { 94 return layer.GetDataType() == DataType::Float16; 95 } 96 }; 97 98 using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>; 99 using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>; 100 101 } //namespace optimizations 102 } //namespace armnn 103