1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker using namespace armnn; 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer") 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations; 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("SquashEqualSiblingsTest") 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker armnn::Graph graph; 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker armnn::LayerBindingId outputId = 0; 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo info({ 1, 2, 3, 5 }, armnn::DataType::Float32); 25*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo permuted({ 1, 5, 2, 3 }, armnn::DataType::Float32); 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker auto input = graph.AddLayer<armnn::InputLayer>(0, "input"); 28*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(info); 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker // Inserts equal permutes, equal reshapes and something else. 31*89c4ff92SAndroid Build Coastguard Worker const armnn::PermuteDescriptor permDesc({ 0, 2, 3, 1 }); 32*89c4ff92SAndroid Build Coastguard Worker const armnn::ReshapeDescriptor reshapeDesc{ { 1, 3, 1, 5 } }; 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* layer; 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, ""); 37*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(permuted); 38*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 39*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 40*89c4ff92SAndroid Build Coastguard Worker 41*89c4ff92SAndroid Build Coastguard Worker layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, ""); 42*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 43*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker layer = graph.AddLayer<armnn::FloorLayer>(""); 46*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 47*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, ""); 50*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 51*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, ""); 54*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(permuted); 55*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 56*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence( 59*89c4ff92SAndroid Build Coastguard Worker graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, &IsLayerOfType<armnn::PermuteLayer>, 60*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ReshapeLayer>, &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::ReshapeLayer>, 61*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, 62*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>)); 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualReshapeSiblings())); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker // The permutes and reshapes are squashed. 67*89c4ff92SAndroid Build Coastguard Worker 68*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, 69*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::ReshapeLayer>, 70*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::OutputLayer>, 71*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, 72*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>)); 73*89c4ff92SAndroid Build Coastguard Worker } 74*89c4ff92SAndroid Build Coastguard Worker 75*89c4ff92SAndroid Build Coastguard Worker }