xref: /aosp_15_r20/external/android-nn-driver/test/TestHalfTensor.hpp (revision 3e777be0405cee09af5d5785ff37f7cfb5bee59a)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <ArmnnDriver.hpp>
#include "DriverTestHelpers.hpp"

#include <half/half.hpp>

#include <utility>
#include <vector>
12 
13 using Half = half_float::half;
14 
15 namespace driverTestHelpers
16 {
17 
18 class TestHalfTensor
19 {
20 public:
TestHalfTensor(const armnn::TensorShape & shape,const std::vector<Half> & data)21     TestHalfTensor(const armnn::TensorShape & shape,
22                const std::vector<Half> & data)
23         : m_Shape{shape}
24         , m_Data{data}
25     {
26         DOCTEST_CHECK(m_Shape.GetNumElements() == m_Data.size());
27     }
28 
29     hidl_vec<uint32_t> GetDimensions() const;
30     unsigned int GetNumElements() const;
31     const Half * GetData() const;
32 
33 private:
34     armnn::TensorShape   m_Shape;
35     std::vector<Half>   m_Data;
36 };
37 
38 } // driverTestHelpers
39