xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/ConvertConstants.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "Optimization.hpp"

#include <armnnUtils/FloatingPointConverter.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <Half.hpp>

#include <memory>
#include <vector>
15 
16 namespace armnn
17 {
18 namespace optimizations
19 {
20 
21 struct Float16ToFloat32
22 {
Funcarmnn::optimizations::Float16ToFloat3223     static void Func(std::shared_ptr<ConstTensorHandle>& handle)
24     {
25         const TensorInfo& info = handle->GetTensorInfo();
26 
27         if (info.GetDataType() == DataType::Float16)
28         {
29             std::vector<float> newValues(info.GetNumElements());
30 
31             armnnUtils::FloatingPointConverter::ConvertFloat16To32(handle->GetConstTensor<Half>(),
32                                                                    info.GetNumElements(),
33                                                                    newValues.data());
34 
35             TensorInfo newInfo(info.GetShape(), DataType::Float32, 0.0f, 0, true);
36             ConstTensor newInput(newInfo, newValues);
37             handle.reset(new ScopedTensorHandle(newInput));
38         }
39     }
40 };
41 
42 struct Float32ToFloat16
43 {
Funcarmnn::optimizations::Float32ToFloat1644     static void Func(std::shared_ptr<ConstTensorHandle>& handle)
45     {
46         const TensorInfo& info = handle->GetTensorInfo();
47 
48         if (info.GetDataType() == DataType::Float32)
49         {
50             std::vector<Half> newValues(info.GetNumElements());
51 
52             armnnUtils::FloatingPointConverter::ConvertFloat32To16(handle->GetConstTensor<float>(),
53                                                                    info.GetNumElements(),
54                                                                    newValues.data());
55 
56             TensorInfo newInfo(info.GetShape(), DataType::Float16, 0.0f, 0, true);
57             ConstTensor newInput(newInfo, newValues);
58             handle.reset(new ScopedTensorHandle(newInput));
59         }
60     }
61 };
62 
63 template<typename Converter, typename Predicate>
64 class ConvertConstants : public Optimization
65 {
66 public:
67     ConvertConstants() = default;
68     ConvertConstants(const ConvertConstants&) = default;
69     virtual ~ConvertConstants() = default;
70 
Run(Graph & graph,Layer & layer) const71     void Run(Graph& graph, Layer& layer) const override
72     {
73         IgnoreUnused(graph);
74         if (Predicate::Test(layer))
75         {
76             layer.OperateOnConstantTensors(Converter::Func);
77         }
78     }
79 protected:
80 };
81 
82 struct IsFloat32Layer
83 {
Testarmnn::optimizations::IsFloat32Layer84     static bool Test(const Layer& layer)
85     {
86         return layer.GetDataType() == DataType::Float32;
87     }
88 };
89 
90 struct IsFloat16Layer
91 {
Testarmnn::optimizations::IsFloat16Layer92     static bool Test(const Layer& layer)
93     {
94         return layer.GetDataType() == DataType::Float16;
95     }
96 };
97 
98 using ConvertConstantsHalfToFloat = ConvertConstants<Float16ToFloat32, IsFloat32Layer>;
99 using ConvertConstantsFloatToHalf = ConvertConstants<Float32ToFloat16, IsFloat16Layer>;
100 
101 } //namespace optimizations
102 } //namespace armnn
103