1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020-2021, 2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker
#include <CommonTestUtils.hpp>

#include <ResolveType.hpp>

#include <armnn/INetwork.hpp>

#include <armnn/utility/NumericCast.hpp>

#include <doctest/doctest.h>

#include <map>
#include <utility>
#include <vector>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker namespace
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnTypeInput>
CreateElementwiseUnaryNetwork(const TensorShape & inputShape,const TensorShape & outputShape,UnaryOperation operation,const float qScale=1.0f,const int32_t qOffset=0)23*89c4ff92SAndroid Build Coastguard Worker INetworkPtr CreateElementwiseUnaryNetwork(const TensorShape& inputShape,
24*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape,
25*89c4ff92SAndroid Build Coastguard Worker UnaryOperation operation,
26*89c4ff92SAndroid Build Coastguard Worker const float qScale = 1.0f,
27*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = 0)
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net(INetwork::Create());
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker ElementwiseUnaryDescriptor descriptor(operation);
34*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* elementwiseUnaryLayer = net->AddElementwiseUnaryLayer(descriptor, "elementwiseUnary");
35*89c4ff92SAndroid Build Coastguard Worker
36*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputTensorInfo(inputShape, ArmnnTypeInput, qScale, qOffset, true);
37*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* input = net->AddInputLayer(armnn::numeric_cast<LayerBindingId>(0));
38*89c4ff92SAndroid Build Coastguard Worker Connect(input, elementwiseUnaryLayer, inputTensorInfo, 0, 0);
39*89c4ff92SAndroid Build Coastguard Worker
40*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputTensorInfo(outputShape, ArmnnTypeInput, qScale, qOffset);
41*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* output = net->AddOutputLayer(0, "output");
42*89c4ff92SAndroid Build Coastguard Worker Connect(elementwiseUnaryLayer, output, outputTensorInfo, 0, 0);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker return net;
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker
47*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnInType,
48*89c4ff92SAndroid Build Coastguard Worker typename TInput = armnn::ResolveType<ArmnnInType>>
ElementwiseUnarySimpleEndToEnd(const std::vector<BackendId> & backends,UnaryOperation operation)49*89c4ff92SAndroid Build Coastguard Worker void ElementwiseUnarySimpleEndToEnd(const std::vector<BackendId>& backends,
50*89c4ff92SAndroid Build Coastguard Worker UnaryOperation operation)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker const float qScale = IsQuantizedType<TInput>() ? 0.25f : 1.0f;
55*89c4ff92SAndroid Build Coastguard Worker const int32_t qOffset = IsQuantizedType<TInput>() ? 50 : 0;
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker const TensorShape& inputShape = { 2, 2, 2, 2 };
58*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outputShape = { 2, 2, 2, 2 };
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
61*89c4ff92SAndroid Build Coastguard Worker INetworkPtr net = CreateElementwiseUnaryNetwork<ArmnnInType>(inputShape, outputShape, operation, qScale, qOffset);
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker std::vector<float> input;
66*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutput;
67*89c4ff92SAndroid Build Coastguard Worker switch(operation)
68*89c4ff92SAndroid Build Coastguard Worker {
69*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Abs:
70*89c4ff92SAndroid Build Coastguard Worker input = { 1, -1, 1, 1, 5, -5, 5, 5,
71*89c4ff92SAndroid Build Coastguard Worker -3, 3, 3, 3, 4, 4, -4, 4 };
72*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
73*89c4ff92SAndroid Build Coastguard Worker 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f };
74*89c4ff92SAndroid Build Coastguard Worker break;
75*89c4ff92SAndroid Build Coastguard Worker case UnaryOperation::Rsqrt:
76*89c4ff92SAndroid Build Coastguard Worker input = { 1, 1, 1, 1, 5, 5, 5, 5,
77*89c4ff92SAndroid Build Coastguard Worker 3, 3, 3, 3, 4, 4, 4, 4 };
78*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 1.f, 1.f, 1.f, 1.f, 0.447214f, 0.447214f, 0.447214f, 0.447214f,
79*89c4ff92SAndroid Build Coastguard Worker 0.57735f, 0.57735f, 0.57735f, 0.57735f, 0.5f, 0.5f, 0.5f, 0.5f };
80*89c4ff92SAndroid Build Coastguard Worker break;
81*89c4ff92SAndroid Build Coastguard Worker default:
82*89c4ff92SAndroid Build Coastguard Worker input = { 1, -1, 1, 1, 5, -5, 5, 5,
83*89c4ff92SAndroid Build Coastguard Worker -3, 3, 3, 3, 4, 4, -4, 4 };
84*89c4ff92SAndroid Build Coastguard Worker expectedOutput = { 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
85*89c4ff92SAndroid Build Coastguard Worker 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f };
86*89c4ff92SAndroid Build Coastguard Worker break;
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker // quantize data
91*89c4ff92SAndroid Build Coastguard Worker std::vector<TInput> qInputData = armnnUtils::QuantizedVector<TInput>(input, qScale, qOffset);
92*89c4ff92SAndroid Build Coastguard Worker std::vector<TInput> qExpectedOutput = armnnUtils::QuantizedVector<TInput>(expectedOutput, qScale, qOffset);
93*89c4ff92SAndroid Build Coastguard Worker
94*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<TInput>> inputTensorData = {{ 0, qInputData }};
95*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<TInput>> expectedOutputData = {{ 0, qExpectedOutput }};
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnInType, ArmnnInType>(move(net), inputTensorData, expectedOutputData, backends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
101