xref: /aosp_15_r20/external/armnn/src/armnn/test/optimizations/ConvertConstantsHalfToFloatTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations;
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ConvertConstantsHalfToFloatTest")
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker     armnn::Graph graph;
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 1, 1, 2 }, armnn::DataType::Float32);
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     // Create the half precision input data
23*89c4ff92SAndroid Build Coastguard Worker     unsigned int dims[] = { 4, 1, 1, 1 };
24*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> convWeightsData{ 1.f, 2.f, 3.f, 4.f };
25*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint16_t> halfWeights(4);
26*89c4ff92SAndroid Build Coastguard Worker     armnnUtils::FloatingPointConverter::ConvertFloat32To16(convWeightsData.data(), convWeightsData.size(),
27*89c4ff92SAndroid Build Coastguard Worker                                                            halfWeights.data());
28*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightInfo = armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true);
29*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightInfo, halfWeights);
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     //Create the simple test network
32*89c4ff92SAndroid Build Coastguard Worker     auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
33*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(info);
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     auto fc      = graph.AddLayer<armnn::FullyConnectedLayer>(armnn::FullyConnectedDescriptor(), "fc");
36*89c4ff92SAndroid Build Coastguard Worker     fc->GetOutputSlot().SetTensorInfo(info);
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     auto weightsLayer = graph.AddLayer<armnn::ConstantLayer>("weights");
39*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights);
40*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo);
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     //Connect up the layers
45*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(fc->GetInputSlot(0));
46*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().Connect(fc->GetInputSlot(1));
47*89c4ff92SAndroid Build Coastguard Worker     fc->GetOutputSlot().Connect(output->GetInputSlot(0));
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     //Test the tensor info is correct.
50*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float16);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker     // Run the optimizer
53*89c4ff92SAndroid Build Coastguard Worker     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(ConvertConstantsHalfToFloat()));
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker     //Test the tensor info is correct.
56*89c4ff92SAndroid Build Coastguard Worker     CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float32);
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     // Now test the data matches float32 data
59*89c4ff92SAndroid Build Coastguard Worker     const float* data = weightsLayer->m_LayerOutput->GetConstTensor<float>();
60*89c4ff92SAndroid Build Coastguard Worker     CHECK(1.0f == data[0]);
61*89c4ff92SAndroid Build Coastguard Worker     CHECK(2.0f == data[1]);
62*89c4ff92SAndroid Build Coastguard Worker     CHECK(3.0f == data[2]);
63*89c4ff92SAndroid Build Coastguard Worker     CHECK(4.0f == data[3]);
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker }