xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/layerTests/CastTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <armnnTestUtils/LayerTestResult.hpp>

#include <ResolveType.hpp>

#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
#include <Half.hpp>

#include <cstdint>
#include <vector>
15 
16 template<armnn::DataType inputDataType, armnn::DataType outputDataType,
17         typename TInput=armnn::ResolveType<inputDataType>,
18         typename TOutput=armnn::ResolveType<outputDataType>>
19 LayerTestResult<TOutput, 4> CastTest(armnn::IWorkloadFactory& workloadFactory,
20                                      const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
21                                      const armnn::ITensorHandleFactory& tensorHandleFactory,
22                                      const std::vector<TInput>& inputTensor,
23                                      const std::vector<TOutput>& outputTensor);
24 
25 
26 LayerTestResult<float, 4> CastInt32ToFloat2dTest(
27         armnn::IWorkloadFactory& workloadFactory,
28         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
29         const armnn::ITensorHandleFactory& tensorHandleFactory);
30 
31 LayerTestResult<float, 4> CastInt16ToFloat2dTest(
32         armnn::IWorkloadFactory& workloadFactory,
33         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
34         const armnn::ITensorHandleFactory& tensorHandleFactory);
35 
36 LayerTestResult<float, 4> CastInt8ToFloat2dTest(
37         armnn::IWorkloadFactory& workloadFactory,
38         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
39         const armnn::ITensorHandleFactory& tensorHandleFactory);
40 
41 LayerTestResult<float, 4> CastInt8AsymmToFloat2dTest(
42         armnn::IWorkloadFactory& workloadFactory,
43         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
44         const armnn::ITensorHandleFactory& tensorHandleFactory);
45 
46 LayerTestResult<float, 4> CastUInt8ToFloat2dTest(
47         armnn::IWorkloadFactory& workloadFactory,
48         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
49         const armnn::ITensorHandleFactory& tensorHandleFactory);
50 
51 LayerTestResult<uint8_t, 4> CastInt8ToUInt82dTest(
52         armnn::IWorkloadFactory& workloadFactory,
53         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
54         const armnn::ITensorHandleFactory& tensorHandleFactory);
55 
56 LayerTestResult<uint8_t, 4> CastInt8AsymmToUInt82dTest(
57         armnn::IWorkloadFactory& workloadFactory,
58         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
59         const armnn::ITensorHandleFactory& tensorHandleFactory);
60 
61 LayerTestResult<float, 4> CastFloat16ToFloat322dTest(
62         armnn::IWorkloadFactory& workloadFactory,
63         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
64         const armnn::ITensorHandleFactory& tensorHandleFactory);
65 
66 LayerTestResult<float, 4> CastBFloat16ToFloat322dTest(
67         armnn::IWorkloadFactory& workloadFactory,
68         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
69         const armnn::ITensorHandleFactory& tensorHandleFactory);
70 
71 LayerTestResult<armnn::Half, 4> CastFloat32ToFloat162dTest(
72         armnn::IWorkloadFactory& workloadFactory,
73         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
74         const armnn::ITensorHandleFactory& tensorHandleFactory);
75 
76 LayerTestResult<int8_t , 4> CastFloat32ToInt82dTest(
77         armnn::IWorkloadFactory& workloadFactory,
78         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
79         const armnn::ITensorHandleFactory& tensorHandleFactory);
80 
81 LayerTestResult<uint8_t , 4> CastFloat32ToUInt82dTest(
82         armnn::IWorkloadFactory& workloadFactory,
83         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
84         const armnn::ITensorHandleFactory& tensorHandleFactory);
85