//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace{
16*89c4ff92SAndroid Build Coastguard Worker
/// Builds a minimal three-layer network: Input(0) -> ChannelShuffle -> Output(0).
///
/// @param inputInfo   Tensor info passed to Connect for the input -> channelShuffle link.
/// @param outputInfo  Tensor info passed to Connect for the channelShuffle -> output link.
/// @param descriptor  Parameters handed to AddChannelShuffleLayer.
/// @return            The newly created network (INetworkPtr), returned to the caller.
armnn::INetworkPtr CreateChannelShuffleNetwork(const armnn::TensorInfo& inputInfo,
                                               const armnn::TensorInfo& outputInfo,
                                               const armnn::ChannelShuffleDescriptor& descriptor)
{
    armnn::INetworkPtr network = armnn::INetwork::Create();

    // Layers are added in the order input, channelShuffle, output.
    armnn::IConnectableLayer* const input   = network->AddInputLayer(0);
    armnn::IConnectableLayer* const shuffle = network->AddChannelShuffleLayer(descriptor, "channelShuffle");
    armnn::IConnectableLayer* const output  = network->AddOutputLayer(0, "output");

    // Wire input -> shuffle -> output, slot 0 to slot 0 on each link.
    Connect(input, shuffle, inputInfo, 0, 0);
    Connect(shuffle, output, outputInfo, 0, 0);

    return network;
}
31*89c4ff92SAndroid Build Coastguard Worker
32*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ChannelShuffleEndToEnd(const std::vector<BackendId> & backends)33*89c4ff92SAndroid Build Coastguard Worker void ChannelShuffleEndToEnd(const std::vector<BackendId>& backends)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo({ 3,12 }, ArmnnType);
36*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo({ 3,12 }, ArmnnType);
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetQuantizationScale(1.0f);
39*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetQuantizationOffset(0);
40*89c4ff92SAndroid Build Coastguard Worker inputInfo.SetConstant(true);
41*89c4ff92SAndroid Build Coastguard Worker outputInfo.SetQuantizationScale(1.0f);
42*89c4ff92SAndroid Build Coastguard Worker outputInfo.SetQuantizationOffset(0);
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker // Creates structures for input & output.
45*89c4ff92SAndroid Build Coastguard Worker std::vector<T> inputData{
46*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
47*89c4ff92SAndroid Build Coastguard Worker 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
48*89c4ff92SAndroid Build Coastguard Worker 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35
49*89c4ff92SAndroid Build Coastguard Worker };
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker std::vector<T> expectedOutput{
52*89c4ff92SAndroid Build Coastguard Worker 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11,
53*89c4ff92SAndroid Build Coastguard Worker 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23,
54*89c4ff92SAndroid Build Coastguard Worker 24, 28, 32, 25, 29, 33, 26, 30, 34, 27, 31, 35
55*89c4ff92SAndroid Build Coastguard Worker };
56*89c4ff92SAndroid Build Coastguard Worker ChannelShuffleDescriptor descriptor;
57*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Axis = 1;
58*89c4ff92SAndroid Build Coastguard Worker descriptor.m_NumGroups = 3;
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker // Builds up the structure of the network
61*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr net = CreateChannelShuffleNetwork(inputInfo, outputInfo, descriptor);
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker CHECK(net);
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> inputTensorData = {{ 0, inputData }};
66*89c4ff92SAndroid Build Coastguard Worker std::map<int, std::vector<T>> expectedOutputData = {{ 0, expectedOutput }};
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker EndToEndLayerTestImpl<ArmnnType, ArmnnType>(move(net), inputTensorData, expectedOutputData, backends);
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
72