1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <TestUtils.hpp> 7 8 #include <Optimizer.hpp> 9 10 #include <doctest/doctest.h> 11 12 using namespace armnn; 13 14 TEST_SUITE("Optimizer") 15 { 16 using namespace armnn::optimizations; 17 18 TEST_CASE("SquashEqualSiblingsTest") 19 { 20 armnn::Graph graph; 21 22 armnn::LayerBindingId outputId = 0; 23 24 const armnn::TensorInfo info({ 1, 2, 3, 5 }, armnn::DataType::Float32); 25 const armnn::TensorInfo permuted({ 1, 5, 2, 3 }, armnn::DataType::Float32); 26 27 auto input = graph.AddLayer<armnn::InputLayer>(0, "input"); 28 input->GetOutputSlot().SetTensorInfo(info); 29 30 // Inserts equal permutes, equal reshapes and something else. 31 const armnn::PermuteDescriptor permDesc({ 0, 2, 3, 1 }); 32 const armnn::ReshapeDescriptor reshapeDesc{ { 1, 3, 1, 5 } }; 33 34 armnn::Layer* layer; 35 36 layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, ""); 37 layer->GetOutputSlot().SetTensorInfo(permuted); 38 layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 39 input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 40 41 layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, ""); 42 layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 43 input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 44 45 layer = graph.AddLayer<armnn::FloorLayer>(""); 46 layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 47 input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 48 49 layer = graph.AddLayer<armnn::ReshapeLayer>(reshapeDesc, ""); 50 layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 51 input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 52 53 layer = graph.AddLayer<armnn::PermuteLayer>(permDesc, ""); 54 layer->GetOutputSlot().SetTensorInfo(permuted); 55 layer->GetOutputSlot().Connect(graph.AddLayer<armnn::OutputLayer>(outputId++, "")->GetInputSlot(0)); 56 input->GetOutputSlot().Connect(layer->GetInputSlot(0)); 57 58 CHECK(CheckSequence( 59 graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, &IsLayerOfType<armnn::PermuteLayer>, 60 &IsLayerOfType<armnn::ReshapeLayer>, &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::ReshapeLayer>, 61 &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, 62 &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>)); 63 64 armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(SquashEqualPermuteSiblings(), SquashEqualReshapeSiblings())); 65 66 // The permutes and reshapes are squashed. 67 68 CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, 69 &IsLayerOfType<armnn::PermuteLayer>, &IsLayerOfType<armnn::ReshapeLayer>, 70 &IsLayerOfType<armnn::FloorLayer>, &IsLayerOfType<armnn::OutputLayer>, 71 &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>, 72 &IsLayerOfType<armnn::OutputLayer>, &IsLayerOfType<armnn::OutputLayer>)); 73 } 74 75 }