//
// Copyright © 2017, 2019-2023  Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <SubgraphViewSelector.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/OptimizationViews.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/SubgraphView.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockBackend.hpp>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
21*89c4ff92SAndroid Build Coastguard Worker 
CheckLayers(Graph & graph)22*89c4ff92SAndroid Build Coastguard Worker void CheckLayers(Graph& graph)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker     unsigned int m_inputLayerCount = 0, m_outputLayerCount = 0, m_addLayerCount = 0;
25*89c4ff92SAndroid Build Coastguard Worker     for(auto layer : graph)
26*89c4ff92SAndroid Build Coastguard Worker     {
27*89c4ff92SAndroid Build Coastguard Worker         switch(layer->GetType())
28*89c4ff92SAndroid Build Coastguard Worker         {
29*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Input:
30*89c4ff92SAndroid Build Coastguard Worker                 ++m_inputLayerCount;
31*89c4ff92SAndroid Build Coastguard Worker                 CHECK((layer->GetName() == std::string("inLayer0") ||
32*89c4ff92SAndroid Build Coastguard Worker                             layer->GetName() == std::string("inLayer1")));
33*89c4ff92SAndroid Build Coastguard Worker                 break;
34*89c4ff92SAndroid Build Coastguard Worker             // The Addition layer should become a PreCompiled Layer after Optimisation
35*89c4ff92SAndroid Build Coastguard Worker             case LayerType::PreCompiled:
36*89c4ff92SAndroid Build Coastguard Worker                 ++m_addLayerCount;
37*89c4ff92SAndroid Build Coastguard Worker                 CHECK(std::string(layer->GetName()) == "pre-compiled");
38*89c4ff92SAndroid Build Coastguard Worker                 break;
39*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Output:
40*89c4ff92SAndroid Build Coastguard Worker                 ++m_outputLayerCount;
41*89c4ff92SAndroid Build Coastguard Worker                 CHECK(std::string(layer->GetName()) == "outLayer");
42*89c4ff92SAndroid Build Coastguard Worker                 break;
43*89c4ff92SAndroid Build Coastguard Worker             default:
44*89c4ff92SAndroid Build Coastguard Worker                 //Fail for anything else
45*89c4ff92SAndroid Build Coastguard Worker                 CHECK(false);
46*89c4ff92SAndroid Build Coastguard Worker         }
47*89c4ff92SAndroid Build Coastguard Worker     }
48*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_inputLayerCount == 2);
49*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_outputLayerCount == 1);
50*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_addLayerCount == 1);
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OptimizationViewsTestSuite")
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCount")
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews view;
58*89c4ff92SAndroid Build Coastguard Worker     // Construct a graph with 3 layers
59*89c4ff92SAndroid Build Coastguard Worker     Graph baseGraph;
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convDescriptor;
64*89c4ff92SAndroid Build Coastguard Worker     PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
65*89c4ff92SAndroid Build Coastguard Worker     Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
66*89c4ff92SAndroid Build Coastguard Worker     Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
67*89c4ff92SAndroid Build Coastguard Worker     Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
68*89c4ff92SAndroid Build Coastguard Worker     Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
69*89c4ff92SAndroid Build Coastguard Worker     Layer* const substitutableCompiledLayer =
70*89c4ff92SAndroid Build Coastguard Worker             baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
75*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
76*89c4ff92SAndroid Build Coastguard Worker     convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
77*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
78*89c4ff92SAndroid Build Coastguard Worker     convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a failed layer
81*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr failedSubgraph =
82*89c4ff92SAndroid Build Coastguard Worker         CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
83*89c4ff92SAndroid Build Coastguard Worker                                CreateOutputsFrom({convLayer1}),
84*89c4ff92SAndroid Build Coastguard Worker                                {convLayer1});
85*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for an untouched layer
86*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr untouchedSubgraph =
87*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
88*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
89*89c4ff92SAndroid Build Coastguard Worker                                    {convLayer2});
90*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a substitutable layer
91*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr substitutableSubgraph =
92*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
93*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
94*89c4ff92SAndroid Build Coastguard Worker                                    {substitutableCompiledLayer});
95*89c4ff92SAndroid Build Coastguard Worker     // Create a Graph containing a layer to substitute in
96*89c4ff92SAndroid Build Coastguard Worker     Graph substitutableGraph;
97*89c4ff92SAndroid Build Coastguard Worker     Layer* const substitutionpreCompiledLayer =
98*89c4ff92SAndroid Build Coastguard Worker             substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a substitution layer
101*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr substitutionSubgraph =
102*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
103*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({substitutionpreCompiledLayer}),
104*89c4ff92SAndroid Build Coastguard Worker                                    {substitutionpreCompiledLayer});
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     // Sub in the graph
107*89c4ff92SAndroid Build Coastguard Worker     baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
110*89c4ff92SAndroid Build Coastguard Worker     view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr baseSubgraph =
113*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
114*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
115*89c4ff92SAndroid Build Coastguard Worker                                    {substitutionpreCompiledLayer});
116*89c4ff92SAndroid Build Coastguard Worker     view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker     // Construct original subgraph to compare against
119*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr originalSubgraph =
120*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
121*89c4ff92SAndroid Build Coastguard Worker             CreateOutputsFrom({convLayer2}),
122*89c4ff92SAndroid Build Coastguard Worker             {convLayer1, convLayer2, substitutionpreCompiledLayer});
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker     CHECK(view.Validate(*originalSubgraph));
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCountUsingGetINetwork")
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews view;
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const inputLayer = view.GetINetwork()->AddInputLayer(0, "input");
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor convDescriptor;
135*89c4ff92SAndroid Build Coastguard Worker     PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
136*89c4ff92SAndroid Build Coastguard Worker     CompiledBlobPtr blobPtr;
137*89c4ff92SAndroid Build Coastguard Worker     BackendId backend = Compute::CpuRef;
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker     Layer* convLayer1 = PolymorphicDowncast<Layer*>(
140*89c4ff92SAndroid Build Coastguard Worker         view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
141*89c4ff92SAndroid Build Coastguard Worker                                                            "conv1"));
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker     Layer* convLayer2 = PolymorphicDowncast<Layer*>(
144*89c4ff92SAndroid Build Coastguard Worker         view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
145*89c4ff92SAndroid Build Coastguard Worker                                                            "conv2"));
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const outputLayer = view.GetINetwork()->AddOutputLayer(0, "output");
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
150*89c4ff92SAndroid Build Coastguard Worker     convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
151*89c4ff92SAndroid Build Coastguard Worker     convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a failed layer
154*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr failedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
155*89c4ff92SAndroid Build Coastguard Worker                                                                                   CreateOutputsFrom({convLayer1}),
156*89c4ff92SAndroid Build Coastguard Worker                                                                                   {convLayer1});
157*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for an untouched layer
158*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr untouchedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
159*89c4ff92SAndroid Build Coastguard Worker                                                                                      CreateOutputsFrom({convLayer2}),
160*89c4ff92SAndroid Build Coastguard Worker                                                                                      {convLayer2});
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     // Create a Network containing a layer to substitute in
163*89c4ff92SAndroid Build Coastguard Worker     NetworkImpl net;
164*89c4ff92SAndroid Build Coastguard Worker     Layer* substitutionpreCompiledLayer = PolymorphicDowncast<Layer*>(
165*89c4ff92SAndroid Build Coastguard Worker         net.AddPrecompiledLayer(substitutionLayerDescriptor, std::move(blobPtr), backend));
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a substitution layer
168*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr substitutionSubgraph =
169*89c4ff92SAndroid Build Coastguard Worker         CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
170*89c4ff92SAndroid Build Coastguard Worker                                                 CreateOutputsFrom({substitutionpreCompiledLayer}),
171*89c4ff92SAndroid Build Coastguard Worker                                                 {substitutionpreCompiledLayer});
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker     view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
174*89c4ff92SAndroid Build Coastguard Worker     view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
175*89c4ff92SAndroid Build Coastguard Worker 
176*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr baseSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
177*89c4ff92SAndroid Build Coastguard Worker                                                                                 CreateOutputsFrom({convLayer2}),
178*89c4ff92SAndroid Build Coastguard Worker                                                                                 {substitutionpreCompiledLayer});
179*89c4ff92SAndroid Build Coastguard Worker     view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     // Construct original subgraph to compare against
182*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr originalSubgraph =
183*89c4ff92SAndroid Build Coastguard Worker         CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
184*89c4ff92SAndroid Build Coastguard Worker                                                 CreateOutputsFrom({convLayer2}),
185*89c4ff92SAndroid Build Coastguard Worker                                                 {convLayer1, convLayer2, substitutionpreCompiledLayer});
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker     CHECK(view.Validate(*originalSubgraph));
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker 
190*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCountFailValidate")
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews view;
193*89c4ff92SAndroid Build Coastguard Worker     // Construct a graph with 3 layers
194*89c4ff92SAndroid Build Coastguard Worker     Graph baseGraph;
195*89c4ff92SAndroid Build Coastguard Worker 
196*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convDescriptor;
199*89c4ff92SAndroid Build Coastguard Worker     PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
200*89c4ff92SAndroid Build Coastguard Worker     Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
201*89c4ff92SAndroid Build Coastguard Worker     Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
202*89c4ff92SAndroid Build Coastguard Worker     Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
203*89c4ff92SAndroid Build Coastguard Worker     Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
204*89c4ff92SAndroid Build Coastguard Worker     Layer* const substitutableCompiledLayer =
205*89c4ff92SAndroid Build Coastguard Worker             baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
211*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
212*89c4ff92SAndroid Build Coastguard Worker     convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
213*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
214*89c4ff92SAndroid Build Coastguard Worker     convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for an untouched layer
217*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr untouchedSubgraph =
218*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
219*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
220*89c4ff92SAndroid Build Coastguard Worker                                    {convLayer2});
221*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a substitutable layer
222*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr substitutableSubgraph =
223*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
224*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
225*89c4ff92SAndroid Build Coastguard Worker                                    {substitutableCompiledLayer});
226*89c4ff92SAndroid Build Coastguard Worker     // Create a Graph containing a layer to substitute in
227*89c4ff92SAndroid Build Coastguard Worker     Graph substitutableGraph;
228*89c4ff92SAndroid Build Coastguard Worker     Layer* const substitutionpreCompiledLayer =
229*89c4ff92SAndroid Build Coastguard Worker             substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     // Subgraph for a substitution layer
232*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr substitutionSubgraph =
233*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
234*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({substitutionpreCompiledLayer}),
235*89c4ff92SAndroid Build Coastguard Worker                                    {substitutionpreCompiledLayer});
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     // Sub in the graph
238*89c4ff92SAndroid Build Coastguard Worker     baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
239*89c4ff92SAndroid Build Coastguard Worker 
240*89c4ff92SAndroid Build Coastguard Worker     view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr baseSubgraph =
243*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
244*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
245*89c4ff92SAndroid Build Coastguard Worker                                    {substitutionpreCompiledLayer});
246*89c4ff92SAndroid Build Coastguard Worker     view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker     // Construct original subgraph to compare against
249*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr originalSubgraph =
250*89c4ff92SAndroid Build Coastguard Worker             CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
251*89c4ff92SAndroid Build Coastguard Worker                                    CreateOutputsFrom({convLayer2}),
252*89c4ff92SAndroid Build Coastguard Worker                                    {convLayer1, convLayer2, substitutionpreCompiledLayer});
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker     // Validate should fail as convLayer1 is not counted
255*89c4ff92SAndroid Build Coastguard Worker     CHECK(!view.Validate(*originalSubgraph));
256*89c4ff92SAndroid Build Coastguard Worker }
257*89c4ff92SAndroid Build Coastguard Worker 
258*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeViewsValidateDeviceMockBackend")
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker     // build up the structure of the network
261*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr net(armnn::INetwork::Create());
262*89c4ff92SAndroid Build Coastguard Worker 
263*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* input = net->AddInputLayer(0, "inLayer0");
264*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* input1 = net->AddInputLayer(1, "inLayer1");
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_BEGIN
267*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* addition = net->AddAdditionLayer("addLayer");
268*89c4ff92SAndroid Build Coastguard Worker     ARMNN_NO_DEPRECATE_WARN_END
269*89c4ff92SAndroid Build Coastguard Worker 
270*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* output = net->AddOutputLayer(0, "outLayer");
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
273*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
274*89c4ff92SAndroid Build Coastguard Worker     addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
277*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
278*89c4ff92SAndroid Build Coastguard Worker     addition->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker     armnn::MockBackendInitialiser initialiser;
281*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntime::CreationOptions options;
282*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
283*89c4ff92SAndroid Build Coastguard Worker 
284*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { MockBackend().GetIdStatic() };
285*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
286*89c4ff92SAndroid Build Coastguard Worker     CHECK(optNet);
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     // Check the optimised graph
289*89c4ff92SAndroid Build Coastguard Worker     armnn::Graph& graph = GetGraphForTesting(optNet.get());
290*89c4ff92SAndroid Build Coastguard Worker     CheckLayers(graph);
291*89c4ff92SAndroid Build Coastguard Worker }
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsReturnsINetworkReference")
294*89c4ff92SAndroid Build Coastguard Worker {
295*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews view;
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker     auto layer = view.GetINetworkRef().AddInputLayer(0, "input");
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker     // Check layer has been added to the referenced INetwork
300*89c4ff92SAndroid Build Coastguard Worker     CHECK(layer);
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker }
305