xref: /aosp_15_r20/external/armnn/src/armnn/test/optimizations/SquashEqualSiblingsTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer")
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations;
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SquashEqualSiblingsTest")
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker     armnn::Graph graph;
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker     armnn::LayerBindingId outputId = 0;
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 2, 3, 5 }, armnn::DataType::Float32);
25*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo permuted({ 1, 5, 2, 3 }, armnn::DataType::Float32);
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker     auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
28*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(info);
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     // Inserts equal permutes, equal reshapes and something else.
31*89c4ff92SAndroid Build Coastguard Worker     const armnn::PermuteDescriptor permDesc({ 0, 2, 3, 1 });
32*89c4ff92SAndroid Build Coastguard Worker     const armnn::ReshapeDescriptor reshapeDesc{ { 1, 3, 1, 5 } };
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* layer;
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker     layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, "");
37*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(permuted);
38*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0));
39*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, "");
42*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0));
43*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     layer = graph.AddLayer<armnn::FloorLayer>("");
46*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0));
47*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, "");
50*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0));
51*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, "");
54*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(permuted);
55*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0));
56*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(
59*89c4ff92SAndroid Build Coastguard Worker         graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, &IsLayerOfType<armnn::PermuteLayer>,
60*89c4ff92SAndroid Build Coastguard Worker         &IsLayerOfType<armnn::ReshapeLayer>, &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::ReshapeLayer>,
61*89c4ff92SAndroid Build Coastguard Worker         &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>,
62*89c4ff92SAndroid Build Coastguard Worker         &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>));
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualReshapeSiblings()));
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker     // The permutes and reshapes are squashed.
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
69*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::ReshapeLayer>,
70*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::OutputLayer>,
71*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>,
72*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>));
73*89c4ff92SAndroid Build Coastguard Worker }
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker }