1 // 2 // Copyright © 2019 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <reference/workloads/ArgMinMax.hpp> 7 8 #include <doctest/doctest.h> 9 10 TEST_SUITE("RefArgMinMax") 11 { 12 TEST_CASE("ArgMinTest") 13 { 14 const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); 15 const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); 16 17 std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f}); 18 std::vector<int64_t> outputValues(outputInfo.GetNumElements()); 19 std::vector<int64_t> expectedValues({ 0, 1, 0 }); 20 21 ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), 22 outputValues.data(), 23 inputInfo, 24 outputInfo, 25 armnn::ArgMinMaxFunction::Min, 26 -2); 27 28 CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); 29 30 } 31 32 TEST_CASE("ArgMaxTest") 33 { 34 const armnn::TensorInfo inputInfo({ 1, 2, 3 } , armnn::DataType::Float32); 35 const armnn::TensorInfo outputInfo({ 1, 3 }, armnn::DataType::Signed64); 36 37 std::vector<float> inputValues({ 1.0f, 5.0f, 3.0f, 4.0f, 2.0f, 6.0f }); 38 std::vector<int64_t> outputValues(outputInfo.GetNumElements()); 39 std::vector<int64_t> expectedValues({ 1, 0, 1 }); 40 41 ArgMinMax(*armnn::MakeDecoder<float>(inputInfo, inputValues.data()), 42 outputValues.data(), 43 inputInfo, 44 outputInfo, 45 armnn::ArgMinMaxFunction::Max, 46 -2); 47 48 CHECK(std::equal(outputValues.begin(), outputValues.end(), expectedValues.begin(), expectedValues.end())); 49 50 } 51 52 }