xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2021, 2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
#include <CommonTestUtils.hpp>

#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <armnn/utility/NumericCast.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnTypeInput>
CreateElementwiseUnaryNetwork(const TensorShape & inputShape,const TensorShape & outputShape,UnaryOperation operation,const float qScale=1.0f,const int32_t qOffset=0)23*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateElementwiseUnaryNetwork(const TensorShape& inputShape,
24*89c4ff92SAndroid Build Coastguard Worker                                           const TensorShape& outputShape,
25*89c4ff92SAndroid Build Coastguard Worker                                           UnaryOperation operation,
26*89c4ff92SAndroid Build Coastguard Worker                                           const float qScale = 1.0f,
27*89c4ff92SAndroid Build Coastguard Worker                                           const int32_t qOffset = 0)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net(INetwork::Create());
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     ElementwiseUnaryDescriptor descriptor(operation);
34*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* elementwiseUnaryLayer = net->AddElementwiseUnaryLayer(descriptor, "elementwiseUnary");
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo(inputShape, ArmnnTypeInput, qScale, qOffset, true);
37*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* input = net->AddInputLayer(armnn::numeric_cast<LayerBindingId>(0));
38*89c4ff92SAndroid Build Coastguard Worker     Connect(input, elementwiseUnaryLayer, inputTensorInfo, 0, 0);
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo(outputShape, ArmnnTypeInput, qScale, qOffset);
41*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* output = net->AddOutputLayer(0, "output");
42*89c4ff92SAndroid Build Coastguard Worker     Connect(elementwiseUnaryLayer, output, outputTensorInfo, 0, 0);
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker     return net;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnInType,
48*89c4ff92SAndroid Build Coastguard Worker          typename TInput = armnn::ResolveType<ArmnnInType>>
ElementwiseUnarySimpleEndToEnd(const std::vector<BackendId> & backends,UnaryOperation operation)49*89c4ff92SAndroid Build Coastguard Worker void ElementwiseUnarySimpleEndToEnd(const std::vector<BackendId>& backends,
50*89c4ff92SAndroid Build Coastguard Worker                                     UnaryOperation operation)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     const float   qScale  = IsQuantizedType<TInput>() ? 0.25f : 1.0f;
55*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = IsQuantizedType<TInput>() ? 50    : 0;
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& inputShape  = { 2, 2, 2, 2 };
58*89c4ff92SAndroid Build Coastguard Worker     const TensorShape& outputShape = { 2, 2, 2, 2 };
59*89c4ff92SAndroid Build Coastguard Worker 
60*89c4ff92SAndroid Build Coastguard Worker     // Builds up the structure of the network
61*89c4ff92SAndroid Build Coastguard Worker     INetworkPtr net = CreateElementwiseUnaryNetwork<ArmnnInType>(inputShape, outputShape, operation, qScale, qOffset);
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     CHECK(net);
64*89c4ff92SAndroid Build Coastguard Worker 
65*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> input;
66*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutput;
67*89c4ff92SAndroid Build Coastguard Worker     switch(operation)
68*89c4ff92SAndroid Build Coastguard Worker     {
69*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Abs:
70*89c4ff92SAndroid Build Coastguard Worker             input = { 1, -1, 1, 1,  5, -5, 5, 5,
71*89c4ff92SAndroid Build Coastguard Worker                       -3, 3, 3, 3,  4, 4, -4, 4 };
72*89c4ff92SAndroid Build Coastguard Worker             expectedOutput = { 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
73*89c4ff92SAndroid Build Coastguard Worker                                3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f };
74*89c4ff92SAndroid Build Coastguard Worker             break;
75*89c4ff92SAndroid Build Coastguard Worker         case UnaryOperation::Rsqrt:
76*89c4ff92SAndroid Build Coastguard Worker             input = { 1, 1, 1, 1,  5, 5, 5, 5,
77*89c4ff92SAndroid Build Coastguard Worker                       3, 3, 3, 3,  4, 4, 4, 4 };
78*89c4ff92SAndroid Build Coastguard Worker             expectedOutput = { 1.f, 1.f, 1.f, 1.f, 0.447214f, 0.447214f, 0.447214f, 0.447214f,
79*89c4ff92SAndroid Build Coastguard Worker                                0.57735f, 0.57735f, 0.57735f, 0.57735f, 0.5f, 0.5f, 0.5f, 0.5f };
80*89c4ff92SAndroid Build Coastguard Worker             break;
81*89c4ff92SAndroid Build Coastguard Worker         default:
82*89c4ff92SAndroid Build Coastguard Worker             input = { 1, -1, 1, 1,  5, -5, 5, 5,
83*89c4ff92SAndroid Build Coastguard Worker                       -3, 3, 3, 3,  4, 4, -4, 4 };
84*89c4ff92SAndroid Build Coastguard Worker             expectedOutput = { 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
85*89c4ff92SAndroid Build Coastguard Worker                                3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f };
86*89c4ff92SAndroid Build Coastguard Worker             break;
87*89c4ff92SAndroid Build Coastguard Worker     }
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     // quantize data
91*89c4ff92SAndroid Build Coastguard Worker     std::vector<TInput> qInputData      = armnnUtils::QuantizedVector<TInput>(input, qScale, qOffset);
92*89c4ff92SAndroid Build Coastguard Worker     std::vector<TInput> qExpectedOutput = armnnUtils::QuantizedVector<TInput>(expectedOutput, qScale, qOffset);
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<TInput>> inputTensorData    = {{ 0, qInputData }};
95*89c4ff92SAndroid Build Coastguard Worker     std::map<int, std::vector<TInput>> expectedOutputData = {{ 0, qExpectedOutput }};
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnInType, ArmnnInType>(move(net), inputTensorData, expectedOutputData, backends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
101