//
// Copyright © 2017, 2019-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <SubgraphViewSelector.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/OptimizationViews.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/SubgraphView.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockBackend.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
21*89c4ff92SAndroid Build Coastguard Worker
CheckLayers(Graph & graph)22*89c4ff92SAndroid Build Coastguard Worker void CheckLayers(Graph& graph)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker unsigned int m_inputLayerCount = 0, m_outputLayerCount = 0, m_addLayerCount = 0;
25*89c4ff92SAndroid Build Coastguard Worker for(auto layer : graph)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker switch(layer->GetType())
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
30*89c4ff92SAndroid Build Coastguard Worker ++m_inputLayerCount;
31*89c4ff92SAndroid Build Coastguard Worker CHECK((layer->GetName() == std::string("inLayer0") ||
32*89c4ff92SAndroid Build Coastguard Worker layer->GetName() == std::string("inLayer1")));
33*89c4ff92SAndroid Build Coastguard Worker break;
34*89c4ff92SAndroid Build Coastguard Worker // The Addition layer should become a PreCompiled Layer after Optimisation
35*89c4ff92SAndroid Build Coastguard Worker case LayerType::PreCompiled:
36*89c4ff92SAndroid Build Coastguard Worker ++m_addLayerCount;
37*89c4ff92SAndroid Build Coastguard Worker CHECK(std::string(layer->GetName()) == "pre-compiled");
38*89c4ff92SAndroid Build Coastguard Worker break;
39*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
40*89c4ff92SAndroid Build Coastguard Worker ++m_outputLayerCount;
41*89c4ff92SAndroid Build Coastguard Worker CHECK(std::string(layer->GetName()) == "outLayer");
42*89c4ff92SAndroid Build Coastguard Worker break;
43*89c4ff92SAndroid Build Coastguard Worker default:
44*89c4ff92SAndroid Build Coastguard Worker //Fail for anything else
45*89c4ff92SAndroid Build Coastguard Worker CHECK(false);
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker }
48*89c4ff92SAndroid Build Coastguard Worker CHECK(m_inputLayerCount == 2);
49*89c4ff92SAndroid Build Coastguard Worker CHECK(m_outputLayerCount == 1);
50*89c4ff92SAndroid Build Coastguard Worker CHECK(m_addLayerCount == 1);
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker
53*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OptimizationViewsTestSuite")
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCount")
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker OptimizationViews view;
58*89c4ff92SAndroid Build Coastguard Worker // Construct a graph with 3 layers
59*89c4ff92SAndroid Build Coastguard Worker Graph baseGraph;
60*89c4ff92SAndroid Build Coastguard Worker
61*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
62*89c4ff92SAndroid Build Coastguard Worker
63*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convDescriptor;
64*89c4ff92SAndroid Build Coastguard Worker PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
65*89c4ff92SAndroid Build Coastguard Worker Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
66*89c4ff92SAndroid Build Coastguard Worker Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
67*89c4ff92SAndroid Build Coastguard Worker Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
68*89c4ff92SAndroid Build Coastguard Worker Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
69*89c4ff92SAndroid Build Coastguard Worker Layer* const substitutableCompiledLayer =
70*89c4ff92SAndroid Build Coastguard Worker baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
75*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
76*89c4ff92SAndroid Build Coastguard Worker convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
77*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
78*89c4ff92SAndroid Build Coastguard Worker convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a failed layer
81*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr failedSubgraph =
82*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
83*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer1}),
84*89c4ff92SAndroid Build Coastguard Worker {convLayer1});
85*89c4ff92SAndroid Build Coastguard Worker // Subgraph for an untouched layer
86*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr untouchedSubgraph =
87*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
88*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
89*89c4ff92SAndroid Build Coastguard Worker {convLayer2});
90*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a substitutable layer
91*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr substitutableSubgraph =
92*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
93*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
94*89c4ff92SAndroid Build Coastguard Worker {substitutableCompiledLayer});
95*89c4ff92SAndroid Build Coastguard Worker // Create a Graph containing a layer to substitute in
96*89c4ff92SAndroid Build Coastguard Worker Graph substitutableGraph;
97*89c4ff92SAndroid Build Coastguard Worker Layer* const substitutionpreCompiledLayer =
98*89c4ff92SAndroid Build Coastguard Worker substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a substitution layer
101*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr substitutionSubgraph =
102*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
103*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({substitutionpreCompiledLayer}),
104*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker // Sub in the graph
107*89c4ff92SAndroid Build Coastguard Worker baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
110*89c4ff92SAndroid Build Coastguard Worker view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr baseSubgraph =
113*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
114*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
115*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
116*89c4ff92SAndroid Build Coastguard Worker view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
117*89c4ff92SAndroid Build Coastguard Worker
118*89c4ff92SAndroid Build Coastguard Worker // Construct original subgraph to compare against
119*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr originalSubgraph =
120*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
121*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
122*89c4ff92SAndroid Build Coastguard Worker {convLayer1, convLayer2, substitutionpreCompiledLayer});
123*89c4ff92SAndroid Build Coastguard Worker
124*89c4ff92SAndroid Build Coastguard Worker CHECK(view.Validate(*originalSubgraph));
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker
127*89c4ff92SAndroid Build Coastguard Worker
128*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCountUsingGetINetwork")
129*89c4ff92SAndroid Build Coastguard Worker {
130*89c4ff92SAndroid Build Coastguard Worker OptimizationViews view;
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const inputLayer = view.GetINetwork()->AddInputLayer(0, "input");
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dDescriptor convDescriptor;
135*89c4ff92SAndroid Build Coastguard Worker PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
136*89c4ff92SAndroid Build Coastguard Worker CompiledBlobPtr blobPtr;
137*89c4ff92SAndroid Build Coastguard Worker BackendId backend = Compute::CpuRef;
138*89c4ff92SAndroid Build Coastguard Worker
139*89c4ff92SAndroid Build Coastguard Worker Layer* convLayer1 = PolymorphicDowncast<Layer*>(
140*89c4ff92SAndroid Build Coastguard Worker view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
141*89c4ff92SAndroid Build Coastguard Worker "conv1"));
142*89c4ff92SAndroid Build Coastguard Worker
143*89c4ff92SAndroid Build Coastguard Worker Layer* convLayer2 = PolymorphicDowncast<Layer*>(
144*89c4ff92SAndroid Build Coastguard Worker view.GetINetwork()->AddDepthwiseConvolution2dLayer(convDescriptor,
145*89c4ff92SAndroid Build Coastguard Worker "conv2"));
146*89c4ff92SAndroid Build Coastguard Worker
147*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* const outputLayer = view.GetINetwork()->AddOutputLayer(0, "output");
148*89c4ff92SAndroid Build Coastguard Worker
149*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
150*89c4ff92SAndroid Build Coastguard Worker convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
151*89c4ff92SAndroid Build Coastguard Worker convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
152*89c4ff92SAndroid Build Coastguard Worker
153*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a failed layer
154*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr failedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
155*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer1}),
156*89c4ff92SAndroid Build Coastguard Worker {convLayer1});
157*89c4ff92SAndroid Build Coastguard Worker // Subgraph for an untouched layer
158*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr untouchedSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
159*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
160*89c4ff92SAndroid Build Coastguard Worker {convLayer2});
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker // Create a Network containing a layer to substitute in
163*89c4ff92SAndroid Build Coastguard Worker NetworkImpl net;
164*89c4ff92SAndroid Build Coastguard Worker Layer* substitutionpreCompiledLayer = PolymorphicDowncast<Layer*>(
165*89c4ff92SAndroid Build Coastguard Worker net.AddPrecompiledLayer(substitutionLayerDescriptor, std::move(blobPtr), backend));
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a substitution layer
168*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr substitutionSubgraph =
169*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
170*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({substitutionpreCompiledLayer}),
171*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
172*89c4ff92SAndroid Build Coastguard Worker
173*89c4ff92SAndroid Build Coastguard Worker view.AddFailedSubgraph(SubgraphView(*failedSubgraph));
174*89c4ff92SAndroid Build Coastguard Worker view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr baseSubgraph = CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
177*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
178*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
179*89c4ff92SAndroid Build Coastguard Worker view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
180*89c4ff92SAndroid Build Coastguard Worker
181*89c4ff92SAndroid Build Coastguard Worker // Construct original subgraph to compare against
182*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr originalSubgraph =
183*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
184*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
185*89c4ff92SAndroid Build Coastguard Worker {convLayer1, convLayer2, substitutionpreCompiledLayer});
186*89c4ff92SAndroid Build Coastguard Worker
187*89c4ff92SAndroid Build Coastguard Worker CHECK(view.Validate(*originalSubgraph));
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker
190*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsSubgraphLayerCountFailValidate")
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker OptimizationViews view;
193*89c4ff92SAndroid Build Coastguard Worker // Construct a graph with 3 layers
194*89c4ff92SAndroid Build Coastguard Worker Graph baseGraph;
195*89c4ff92SAndroid Build Coastguard Worker
196*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = baseGraph.AddLayer<InputLayer>(0, "input");
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convDescriptor;
199*89c4ff92SAndroid Build Coastguard Worker PreCompiledDescriptor substitutionLayerDescriptor(2, 1);
200*89c4ff92SAndroid Build Coastguard Worker Layer* const convLayer1 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv1");
201*89c4ff92SAndroid Build Coastguard Worker Layer* const convLayer2 = baseGraph.AddLayer<Convolution2dLayer>(convDescriptor, "conv2");
202*89c4ff92SAndroid Build Coastguard Worker Layer* const weightsLayer1 = baseGraph.AddLayer<ConstantLayer>("weights1");
203*89c4ff92SAndroid Build Coastguard Worker Layer* const weightsLayer2 = baseGraph.AddLayer<ConstantLayer>("weights2");
204*89c4ff92SAndroid Build Coastguard Worker Layer* const substitutableCompiledLayer =
205*89c4ff92SAndroid Build Coastguard Worker baseGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
206*89c4ff92SAndroid Build Coastguard Worker
207*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = baseGraph.AddLayer<OutputLayer>(0, "output");
208*89c4ff92SAndroid Build Coastguard Worker
209*89c4ff92SAndroid Build Coastguard Worker
210*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(0));
211*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(convLayer1->GetInputSlot(1));
212*89c4ff92SAndroid Build Coastguard Worker convLayer1->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(0));
213*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(convLayer2->GetInputSlot(1));
214*89c4ff92SAndroid Build Coastguard Worker convLayer2->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
215*89c4ff92SAndroid Build Coastguard Worker
216*89c4ff92SAndroid Build Coastguard Worker // Subgraph for an untouched layer
217*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr untouchedSubgraph =
218*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer2),
219*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
220*89c4ff92SAndroid Build Coastguard Worker {convLayer2});
221*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a substitutable layer
222*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr substitutableSubgraph =
223*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
224*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
225*89c4ff92SAndroid Build Coastguard Worker {substitutableCompiledLayer});
226*89c4ff92SAndroid Build Coastguard Worker // Create a Graph containing a layer to substitute in
227*89c4ff92SAndroid Build Coastguard Worker Graph substitutableGraph;
228*89c4ff92SAndroid Build Coastguard Worker Layer* const substitutionpreCompiledLayer =
229*89c4ff92SAndroid Build Coastguard Worker substitutableGraph.AddLayer<PreCompiledLayer>(substitutionLayerDescriptor, "pre-compiled");
230*89c4ff92SAndroid Build Coastguard Worker
231*89c4ff92SAndroid Build Coastguard Worker // Subgraph for a substitution layer
232*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr substitutionSubgraph =
233*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(substitutionpreCompiledLayer),
234*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({substitutionpreCompiledLayer}),
235*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker // Sub in the graph
238*89c4ff92SAndroid Build Coastguard Worker baseGraph.SubstituteSubgraph(*substitutableSubgraph, *substitutionSubgraph);
239*89c4ff92SAndroid Build Coastguard Worker
240*89c4ff92SAndroid Build Coastguard Worker view.AddUntouchedSubgraph(SubgraphView(*untouchedSubgraph));
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr baseSubgraph =
243*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
244*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
245*89c4ff92SAndroid Build Coastguard Worker {substitutionpreCompiledLayer});
246*89c4ff92SAndroid Build Coastguard Worker view.AddSubstitution({*baseSubgraph, *substitutionSubgraph});
247*89c4ff92SAndroid Build Coastguard Worker
248*89c4ff92SAndroid Build Coastguard Worker // Construct original subgraph to compare against
249*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr originalSubgraph =
250*89c4ff92SAndroid Build Coastguard Worker CreateSubgraphViewFrom(CreateInputsFrom(convLayer1),
251*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer2}),
252*89c4ff92SAndroid Build Coastguard Worker {convLayer1, convLayer2, substitutionpreCompiledLayer});
253*89c4ff92SAndroid Build Coastguard Worker
254*89c4ff92SAndroid Build Coastguard Worker // Validate should fail as convLayer1 is not counted
255*89c4ff92SAndroid Build Coastguard Worker CHECK(!view.Validate(*originalSubgraph));
256*89c4ff92SAndroid Build Coastguard Worker }
257*89c4ff92SAndroid Build Coastguard Worker
258*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeViewsValidateDeviceMockBackend")
259*89c4ff92SAndroid Build Coastguard Worker {
260*89c4ff92SAndroid Build Coastguard Worker // build up the structure of the network
261*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr net(armnn::INetwork::Create());
262*89c4ff92SAndroid Build Coastguard Worker
263*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* input = net->AddInputLayer(0, "inLayer0");
264*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* input1 = net->AddInputLayer(1, "inLayer1");
265*89c4ff92SAndroid Build Coastguard Worker
266*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_BEGIN
267*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* addition = net->AddAdditionLayer("addLayer");
268*89c4ff92SAndroid Build Coastguard Worker ARMNN_NO_DEPRECATE_WARN_END
269*89c4ff92SAndroid Build Coastguard Worker
270*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* output = net->AddOutputLayer(0, "outLayer");
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).Connect(addition->GetInputSlot(0));
273*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).Connect(addition->GetInputSlot(1));
274*89c4ff92SAndroid Build Coastguard Worker addition->GetOutputSlot(0).Connect(output->GetInputSlot(0));
275*89c4ff92SAndroid Build Coastguard Worker
276*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
277*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
278*89c4ff92SAndroid Build Coastguard Worker addition->GetOutputSlot(0).SetTensorInfo(armnn::TensorInfo({ 1, 1, 4, 4 }, armnn::DataType::Float32));
279*89c4ff92SAndroid Build Coastguard Worker
280*89c4ff92SAndroid Build Coastguard Worker armnn::MockBackendInitialiser initialiser;
281*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime::CreationOptions options;
282*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
283*89c4ff92SAndroid Build Coastguard Worker
284*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = { MockBackend().GetIdStatic() };
285*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet = armnn::Optimize(*net, backends, runtime->GetDeviceSpec());
286*89c4ff92SAndroid Build Coastguard Worker CHECK(optNet);
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker // Check the optimised graph
289*89c4ff92SAndroid Build Coastguard Worker armnn::Graph& graph = GetGraphForTesting(optNet.get());
290*89c4ff92SAndroid Build Coastguard Worker CheckLayers(graph);
291*89c4ff92SAndroid Build Coastguard Worker }
292*89c4ff92SAndroid Build Coastguard Worker
293*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizedViewsReturnsINetworkReference")
294*89c4ff92SAndroid Build Coastguard Worker {
295*89c4ff92SAndroid Build Coastguard Worker OptimizationViews view;
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker auto layer = view.GetINetworkRef().AddInputLayer(0, "input");
298*89c4ff92SAndroid Build Coastguard Worker
299*89c4ff92SAndroid Build Coastguard Worker // Check layer has been added to the referenced INetwork
300*89c4ff92SAndroid Build Coastguard Worker CHECK(layer);
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker
303*89c4ff92SAndroid Build Coastguard Worker
304*89c4ff92SAndroid Build Coastguard Worker }
305