1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer") 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations; 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ConvertConstantsHalfToFloatTest") 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker armnn::Graph graph; 19*89c4ff92SAndroid Build Coastguard Worker 20*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo info({ 1, 1, 1, 2 }, armnn::DataType::Float32); 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker // Create the half precision input data 23*89c4ff92SAndroid Build Coastguard Worker unsigned int dims[] = { 4, 1, 1, 1 }; 24*89c4ff92SAndroid Build Coastguard Worker std::vector<float> convWeightsData{ 1.f, 2.f, 3.f, 4.f }; 25*89c4ff92SAndroid Build Coastguard Worker std::vector<uint16_t> halfWeights(4); 26*89c4ff92SAndroid Build Coastguard Worker armnnUtils::FloatingPointConverter::ConvertFloat32To16(convWeightsData.data(), convWeightsData.size(), 27*89c4ff92SAndroid Build Coastguard Worker halfWeights.data()); 28*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightInfo = armnn::TensorInfo(4, dims, armnn::DataType::Float16, 0.0f, 0, true); 29*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights(weightInfo, halfWeights); 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker //Create the simple test network 32*89c4ff92SAndroid Build Coastguard Worker auto input = graph.AddLayer<armnn::InputLayer>(0, "input"); 33*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(info); 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker auto fc = graph.AddLayer<armnn::FullyConnectedLayer>(armnn::FullyConnectedDescriptor(), "fc"); 36*89c4ff92SAndroid Build Coastguard Worker fc->GetOutputSlot().SetTensorInfo(info); 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = graph.AddLayer<armnn::ConstantLayer>("weights"); 39*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights); 40*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightInfo); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker auto output = graph.AddLayer<armnn::OutputLayer>(1, "output"); 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker //Connect up the layers 45*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(fc->GetInputSlot(0)); 46*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().Connect(fc->GetInputSlot(1)); 47*89c4ff92SAndroid Build Coastguard Worker fc->GetOutputSlot().Connect(output->GetInputSlot(0)); 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker //Test the tensor info is correct. 50*89c4ff92SAndroid Build Coastguard Worker CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float16); 51*89c4ff92SAndroid Build Coastguard Worker 52*89c4ff92SAndroid Build Coastguard Worker // Run the optimizer 53*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(ConvertConstantsHalfToFloat())); 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker //Test the tensor info is correct. 56*89c4ff92SAndroid Build Coastguard Worker CHECK(weightsLayer->m_LayerOutput->GetTensorInfo().GetDataType() == armnn::DataType::Float32); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker // Now test the data matches float32 data 59*89c4ff92SAndroid Build Coastguard Worker const float* data = weightsLayer->m_LayerOutput->GetConstTensor<float>(); 60*89c4ff92SAndroid Build Coastguard Worker CHECK(1.0f == data[0]); 61*89c4ff92SAndroid Build Coastguard Worker CHECK(2.0f == data[1]); 62*89c4ff92SAndroid Build Coastguard Worker CHECK(3.0f == data[2]); 63*89c4ff92SAndroid Build Coastguard Worker CHECK(4.0f == data[3]); 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker }