xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/OptimizeSubgraphViewTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017, 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker #include "MockBackendId.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockBackend.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <unordered_map>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
19*89c4ff92SAndroid Build Coastguard Worker 
20*89c4ff92SAndroid Build Coastguard Worker namespace
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker 
// Sizes a subgraph is expected to have once a test has run. Each count defaults to zero,
// so a default-constructed value describes an empty subgraph.
struct ExpectedSubgraphSize
{
    size_t m_NumInputSlots{0};   // expected number of input slots
    size_t m_NumOutputSlots{0};  // expected number of output slots
    size_t m_NumLayers{0};       // expected number of layers
};
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker // Keep the layers organized by layer name
32*89c4ff92SAndroid Build Coastguard Worker using LayerNameToLayerMap = std::unordered_map<std::string, Layer*>;
33*89c4ff92SAndroid Build Coastguard Worker 
// Turns a const reference to a slot (graphs hand slots out by reference) into a mutable
// pointer (the form subgraph views store). Constness is cast away: writing through the
// returned pointer is only valid if the referenced object is not itself const.
template <typename SlotType>
SlotType* ConvertReferenceTypeToPointerType(const SlotType& input)
{
    SlotType& mutableSlot = const_cast<SlotType&>(input);
    return &mutableSlot;
}
41*89c4ff92SAndroid Build Coastguard Worker 
// Used to convert input and output slots from reference type (as stored in graphs) to
// pointer type (as stored in subgraphs), array version.
// Returns one pointer per element of `input`, in the same order, each pointing at the
// element itself (constness cast away). The result is only valid while `input` is alive
// and not reallocated.
template <typename SlotType>
std::vector<SlotType*> ConvertReferenceTypeToPointerType(const std::vector<SlotType>& input)
{
    std::vector<SlotType*> output;
    // Exactly one pointer per element: allocate once up front.
    output.reserve(input.size());
    for (const SlotType& inputItem : input)
    {
        output.push_back(const_cast<SlotType*>(&inputItem));
    }
    return output;
}
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker // Convert from vector of Slots* (Input/Output) to vector of ISlots* (IInput/IOutput)
60*89c4ff92SAndroid Build Coastguard Worker template <typename SlotType, typename ResultSlotType>
ConvertSlotsToISlots(const std::vector<SlotType * > input)61*89c4ff92SAndroid Build Coastguard Worker std::vector<ResultSlotType*> ConvertSlotsToISlots(const std::vector<SlotType*> input)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker     std::vector<ResultSlotType*> output;
64*89c4ff92SAndroid Build Coastguard Worker     for (auto slot : input)
65*89c4ff92SAndroid Build Coastguard Worker     {
66*89c4ff92SAndroid Build Coastguard Worker         output.push_back(PolymorphicDowncast<ResultSlotType*>(slot));
67*89c4ff92SAndroid Build Coastguard Worker     }
68*89c4ff92SAndroid Build Coastguard Worker     return output;
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an input layer to a graph
AddInputLayer(Graph & graph,const std::string & layerName,const TensorInfo & inputInfo,LayerBindingId inputId=0)72*89c4ff92SAndroid Build Coastguard Worker Layer* AddInputLayer(Graph& graph,
73*89c4ff92SAndroid Build Coastguard Worker                      const std::string& layerName,
74*89c4ff92SAndroid Build Coastguard Worker                      const TensorInfo& inputInfo,
75*89c4ff92SAndroid Build Coastguard Worker                      LayerBindingId inputId = 0)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = graph.AddLayer<InputLayer>(inputId, layerName.c_str());
78*89c4ff92SAndroid Build Coastguard Worker     CHECK(inputLayer);
79*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
80*89c4ff92SAndroid Build Coastguard Worker     return inputLayer;
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an output layer to a graph
AddOutputLayer(Graph & graph,const std::string & layerName)84*89c4ff92SAndroid Build Coastguard Worker Layer* AddOutputLayer(Graph& graph,
85*89c4ff92SAndroid Build Coastguard Worker                       const std::string& layerName)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = graph.AddLayer<OutputLayer>(0, layerName.c_str());
88*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputLayer);
89*89c4ff92SAndroid Build Coastguard Worker     return outputLayer;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a convolution layer to a graph
AddConvolutionLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const Convolution2dDescriptor & convolutionDescriptor,const std::string & layerName,const TensorInfo & outputInfo)93*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* AddConvolutionLayer(Graph& graph,
94*89c4ff92SAndroid Build Coastguard Worker                                         LayerNameToLayerMap& layersInGraph,
95*89c4ff92SAndroid Build Coastguard Worker                                         const Convolution2dDescriptor& convolutionDescriptor,
96*89c4ff92SAndroid Build Coastguard Worker                                         const std::string& layerName,
97*89c4ff92SAndroid Build Coastguard Worker                                         const TensorInfo& outputInfo)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const convLayer = graph.AddLayer<Convolution2dLayer>(convolutionDescriptor, layerName.c_str());
100*89c4ff92SAndroid Build Coastguard Worker     CHECK(convLayer);
101*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
102*89c4ff92SAndroid Build Coastguard Worker     layersInGraph.insert(std::make_pair(convLayer->GetName(), convLayer));
103*89c4ff92SAndroid Build Coastguard Worker     return convLayer;
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a constant layer to a graph
AddConstantLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const std::string & layerName,const ConstTensor & constTensor,const TensorInfo & outputInfo)107*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* AddConstantLayer(Graph& graph,
108*89c4ff92SAndroid Build Coastguard Worker                                 LayerNameToLayerMap& layersInGraph,
109*89c4ff92SAndroid Build Coastguard Worker                                 const std::string& layerName,
110*89c4ff92SAndroid Build Coastguard Worker                                 const ConstTensor& constTensor,
111*89c4ff92SAndroid Build Coastguard Worker                                 const TensorInfo& outputInfo)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const constantLayer = graph.AddLayer<ConstantLayer>(layerName.c_str());
114*89c4ff92SAndroid Build Coastguard Worker     CHECK(constantLayer);
115*89c4ff92SAndroid Build Coastguard Worker     constantLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor);
116*89c4ff92SAndroid Build Coastguard Worker     constantLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
117*89c4ff92SAndroid Build Coastguard Worker     layersInGraph.insert(std::make_pair(constantLayer->GetName(), constantLayer));
118*89c4ff92SAndroid Build Coastguard Worker     return constantLayer;
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a pooling layer to a graph
AddPoolingLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const Pooling2dDescriptor & poolingDescriptor,const std::string & layerName,const TensorInfo & outputInfo)122*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* AddPoolingLayer(Graph& graph,
123*89c4ff92SAndroid Build Coastguard Worker                                 LayerNameToLayerMap& layersInGraph,
124*89c4ff92SAndroid Build Coastguard Worker                                 const Pooling2dDescriptor& poolingDescriptor,
125*89c4ff92SAndroid Build Coastguard Worker                                 const std::string& layerName,
126*89c4ff92SAndroid Build Coastguard Worker                                 const TensorInfo& outputInfo)
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const poolingLayer = graph.AddLayer<Pooling2dLayer>(poolingDescriptor, layerName.c_str());
129*89c4ff92SAndroid Build Coastguard Worker     CHECK(poolingLayer);
130*89c4ff92SAndroid Build Coastguard Worker     poolingLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
131*89c4ff92SAndroid Build Coastguard Worker     layersInGraph.insert(std::make_pair(poolingLayer->GetName(), poolingLayer));
132*89c4ff92SAndroid Build Coastguard Worker     return poolingLayer;
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an addition layer to a graph
AddAdditionaLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const std::string & layerName,const TensorInfo & outputInfo)136*89c4ff92SAndroid Build Coastguard Worker AdditionLayer* AddAdditionaLayer(Graph& graph,
137*89c4ff92SAndroid Build Coastguard Worker                                  LayerNameToLayerMap& layersInGraph,
138*89c4ff92SAndroid Build Coastguard Worker                                  const std::string& layerName,
139*89c4ff92SAndroid Build Coastguard Worker                                  const TensorInfo& outputInfo)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     AdditionLayer* const additionLayer = graph.AddLayer<AdditionLayer>(layerName.c_str());
142*89c4ff92SAndroid Build Coastguard Worker     CHECK(additionLayer);
143*89c4ff92SAndroid Build Coastguard Worker     additionLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
144*89c4ff92SAndroid Build Coastguard Worker     layersInGraph.insert(std::make_pair(additionLayer->GetName(), additionLayer));
145*89c4ff92SAndroid Build Coastguard Worker     return additionLayer;
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given substitution matches the specified expected values
CheckSubstitution(const OptimizationViews::SubstitutionPair & substitution,const ExpectedSubgraphSize & expectedSubstitutableSubgraphSize,const ExpectedSubgraphSize & expectedReplacementSubgraphSize,const SubgraphView::IInputSlots & expectedSubstitutableInputSlots,const SubgraphView::IOutputSlots & expectedSubstitutableOutputSlots,const SubgraphView::IConnectableLayers & expectedSubstitutableLayers)149*89c4ff92SAndroid Build Coastguard Worker void CheckSubstitution(const OptimizationViews::SubstitutionPair& substitution,
150*89c4ff92SAndroid Build Coastguard Worker                        const ExpectedSubgraphSize& expectedSubstitutableSubgraphSize,
151*89c4ff92SAndroid Build Coastguard Worker                        const ExpectedSubgraphSize& expectedReplacementSubgraphSize,
152*89c4ff92SAndroid Build Coastguard Worker                        const SubgraphView::IInputSlots& expectedSubstitutableInputSlots,
153*89c4ff92SAndroid Build Coastguard Worker                        const SubgraphView::IOutputSlots& expectedSubstitutableOutputSlots,
154*89c4ff92SAndroid Build Coastguard Worker                        const SubgraphView::IConnectableLayers& expectedSubstitutableLayers)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView& substitutableSubgraph = substitution.m_SubstitutableSubgraph;
157*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& substitutableSubgraphInputSlots = substitutableSubgraph.GetIInputSlots();
158*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& substitutableSubgraphOutputSlots = substitutableSubgraph.GetIOutputSlots();
159*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& substitutableSubgraphLayers =
160*89c4ff92SAndroid Build Coastguard Worker             substitutableSubgraph.GetIConnectableLayers();
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView& replacementSubgraph                          = substitution.m_ReplacementSubgraph;
163*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& replacementSubgraphInputSlots   = replacementSubgraph.GetIInputSlots();
164*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& replacementSubgraphOutputSlots = replacementSubgraph.GetIOutputSlots();
165*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& replacementSubgraphLayers = replacementSubgraph.GetIConnectableLayers();
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutableSubgraphInputSlots.size()  == expectedSubstitutableSubgraphSize.m_NumInputSlots);
168*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutableSubgraphOutputSlots.size() == expectedSubstitutableSubgraphSize.m_NumOutputSlots);
169*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutableSubgraphLayers.size()      == expectedSubstitutableSubgraphSize.m_NumLayers);
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(substitutableSubgraphInputSlots,  expectedSubstitutableInputSlots));
172*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(substitutableSubgraphOutputSlots, expectedSubstitutableOutputSlots));
173*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(substitutableSubgraphLayers,      expectedSubstitutableLayers));
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker     CHECK(replacementSubgraphInputSlots.size()  == expectedReplacementSubgraphSize.m_NumInputSlots);
176*89c4ff92SAndroid Build Coastguard Worker     CHECK(replacementSubgraphOutputSlots.size() == expectedReplacementSubgraphSize.m_NumOutputSlots);
177*89c4ff92SAndroid Build Coastguard Worker     CHECK(replacementSubgraphLayers.size()      == expectedReplacementSubgraphSize.m_NumLayers);
178*89c4ff92SAndroid Build Coastguard Worker 
179*89c4ff92SAndroid Build Coastguard Worker     CHECK(!AreEqual(replacementSubgraphInputSlots,  expectedSubstitutableInputSlots));
180*89c4ff92SAndroid Build Coastguard Worker     CHECK(!AreEqual(replacementSubgraphOutputSlots, expectedSubstitutableOutputSlots));
181*89c4ff92SAndroid Build Coastguard Worker     CHECK(!AreEqual(replacementSubgraphLayers,      expectedSubstitutableLayers));
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker     CHECK(std::all_of(replacementSubgraphLayers.begin(),
184*89c4ff92SAndroid Build Coastguard Worker                            replacementSubgraphLayers.end(),
185*89c4ff92SAndroid Build Coastguard Worker                            [](const IConnectableLayer* layer)
186*89c4ff92SAndroid Build Coastguard Worker     {
187*89c4ff92SAndroid Build Coastguard Worker         return layer->GetType() == LayerType::PreCompiled;
188*89c4ff92SAndroid Build Coastguard Worker     }));
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given failed subgraph matches the specified expected values
CheckFailedSubgraph(const SubgraphView & failedSubgraph,const ExpectedSubgraphSize & expectedFailedSubgraphSize,const SubgraphView::IInputSlots & expectedFailedInputSlots,const SubgraphView::IOutputSlots & expectedFailedOutputSlots,const SubgraphView::IConnectableLayers & expectedFailedLayers)192*89c4ff92SAndroid Build Coastguard Worker void CheckFailedSubgraph(const SubgraphView& failedSubgraph,
193*89c4ff92SAndroid Build Coastguard Worker                          const ExpectedSubgraphSize& expectedFailedSubgraphSize,
194*89c4ff92SAndroid Build Coastguard Worker                          const SubgraphView::IInputSlots& expectedFailedInputSlots,
195*89c4ff92SAndroid Build Coastguard Worker                          const SubgraphView::IOutputSlots& expectedFailedOutputSlots,
196*89c4ff92SAndroid Build Coastguard Worker                          const SubgraphView::IConnectableLayers& expectedFailedLayers)
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots&  failedSubgraphInputSlots  = failedSubgraph.GetIInputSlots();
199*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& failedSubgraphOutputSlots = failedSubgraph.GetIOutputSlots();
200*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& failedSubgraphLayers = failedSubgraph.GetIConnectableLayers();
201*89c4ff92SAndroid Build Coastguard Worker 
202*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphInputSlots.size()  == expectedFailedSubgraphSize.m_NumInputSlots);
203*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphOutputSlots.size() == expectedFailedSubgraphSize.m_NumOutputSlots);
204*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphLayers.size()      == expectedFailedSubgraphSize.m_NumLayers);
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(failedSubgraphInputSlots,  expectedFailedInputSlots));
207*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(failedSubgraphOutputSlots, expectedFailedOutputSlots));
208*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(failedSubgraphLayers,      expectedFailedLayers));
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given untouched subgraph matches the specified expected values
CheckUntouchedSubgraph(const SubgraphView & untouchedSubgraph,const ExpectedSubgraphSize & expectedUntouchedSubgraphSize,const SubgraphView::IInputSlots & expectedUntouchedInputSlots,const SubgraphView::IOutputSlots & expectedUntouchedOutputSlots,const SubgraphView::IConnectableLayers & expectedUntouchedLayers)212*89c4ff92SAndroid Build Coastguard Worker void CheckUntouchedSubgraph(const SubgraphView& untouchedSubgraph,
213*89c4ff92SAndroid Build Coastguard Worker                             const ExpectedSubgraphSize& expectedUntouchedSubgraphSize,
214*89c4ff92SAndroid Build Coastguard Worker                             const SubgraphView::IInputSlots& expectedUntouchedInputSlots,
215*89c4ff92SAndroid Build Coastguard Worker                             const SubgraphView::IOutputSlots& expectedUntouchedOutputSlots,
216*89c4ff92SAndroid Build Coastguard Worker                             const SubgraphView::IConnectableLayers& expectedUntouchedLayers)
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& untouchedSubgraphInputSlots = untouchedSubgraph.GetIInputSlots();
219*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& untouchedSubgraphOutputSlots = untouchedSubgraph.GetIOutputSlots();
220*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& untouchedSubgraphLayers = untouchedSubgraph.GetIConnectableLayers();
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphInputSlots.size()  == expectedUntouchedSubgraphSize.m_NumInputSlots);
223*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphOutputSlots.size() == expectedUntouchedSubgraphSize.m_NumOutputSlots);
224*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphLayers.size()      == expectedUntouchedSubgraphSize.m_NumLayers);
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(untouchedSubgraphInputSlots,  expectedUntouchedInputSlots));
227*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(untouchedSubgraphOutputSlots, expectedUntouchedOutputSlots));
228*89c4ff92SAndroid Build Coastguard Worker     CHECK(AreEqual(untouchedSubgraphLayers,      expectedUntouchedLayers));
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph containing only a single unsupported layer (only convolutions are unsupported by the mock backend)
BuildFullyUnsupportedSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)232*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnsupportedSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
233*89c4ff92SAndroid Build Coastguard Worker {
234*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
235*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
236*89c4ff92SAndroid Build Coastguard Worker 
237*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor poolingDescriptor;
238*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolType      = armnn::PoolingAlgorithm::Average;
239*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolWidth     = 2;
240*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolHeight    = 2;
241*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideX       = 2;
242*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideY       = 2;
243*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadLeft       = 1;
244*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadRight      = 1;
245*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadTop        = 1;
246*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadBottom     = 1;
247*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
248*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_DataLayout    = DataLayout::NHWC;
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
251*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
252*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const poolingLayer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
253*89c4ff92SAndroid Build Coastguard Worker                                                          "pooling layer", outputInfo);
254*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
257*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(poolingLayer->GetInputSlot(0));
258*89c4ff92SAndroid Build Coastguard Worker     poolingLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
261*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(poolingLayer),
262*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({poolingLayer}),
263*89c4ff92SAndroid Build Coastguard Worker                                   {poolingLayer});
264*89c4ff92SAndroid Build Coastguard Worker }
265*89c4ff92SAndroid Build Coastguard Worker 
266*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph containing only unsupported layers (only convolutions are unsupported by the mock backend)
BuildFullyUnsupportedSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)267*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnsupportedSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
270*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor poolingDescriptor;
273*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolType      = armnn::PoolingAlgorithm::Average;
274*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolWidth     = 2;
275*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolHeight    = 2;
276*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideX       = 2;
277*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideY       = 2;
278*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadLeft       = 1;
279*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadRight      = 1;
280*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadTop        = 1;
281*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadBottom     = 1;
282*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
283*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_DataLayout    = DataLayout::NHWC;
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
286*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
287*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling1Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
288*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling1 layer", outputInfo);
289*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling2Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
290*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling2 layer", outputInfo);
291*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling3Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
292*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling3 layer", outputInfo);
293*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
294*89c4ff92SAndroid Build Coastguard Worker 
295*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
296*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(pooling1Layer->GetInputSlot(0));
297*89c4ff92SAndroid Build Coastguard Worker     pooling1Layer->GetOutputSlot(0).Connect(pooling2Layer->GetInputSlot(0));
298*89c4ff92SAndroid Build Coastguard Worker     pooling2Layer->GetOutputSlot(0).Connect(pooling3Layer->GetInputSlot(0));
299*89c4ff92SAndroid Build Coastguard Worker     pooling3Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
302*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(pooling1Layer),
303*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({pooling3Layer}),
304*89c4ff92SAndroid Build Coastguard Worker                                   {pooling1Layer,
305*89c4ff92SAndroid Build Coastguard Worker                                    pooling2Layer,
306*89c4ff92SAndroid Build Coastguard Worker                                    pooling3Layer});
307*89c4ff92SAndroid Build Coastguard Worker }
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker // Creates a simple subgraph with only one convolution layer, supported by the mock backend
BuildFullyOptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)310*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyOptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
313*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
314*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
315*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
316*89c4ff92SAndroid Build Coastguard Worker 
317*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
318*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
319*89c4ff92SAndroid Build Coastguard Worker 
320*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
321*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
322*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
323*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
324*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
325*89c4ff92SAndroid Build Coastguard Worker 
326*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
327*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
328*89c4ff92SAndroid Build Coastguard Worker 
329*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
330*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
331*89c4ff92SAndroid Build Coastguard Worker 
332*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
333*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
334*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const convLayer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
335*89c4ff92SAndroid Build Coastguard Worker                                                               "conv layer", outputInfo);
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker   ConstantLayer* const weightsLayer =
338*89c4ff92SAndroid Build Coastguard Worker       AddConstantLayer(graph, layersInGraph, "Weights Layer", constWeightsTensor, weightInfo);
339*89c4ff92SAndroid Build Coastguard Worker   ConstantLayer* const biasLayer = AddConstantLayer(graph, layersInGraph, "Bias Layer", constBiasTensor, biasInfo);
340*89c4ff92SAndroid Build Coastguard Worker 
341*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
342*89c4ff92SAndroid Build Coastguard Worker 
343*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
344*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
345*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
346*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
347*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
350*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
351*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(convLayer, ignoreSlots),
352*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({convLayer}),
353*89c4ff92SAndroid Build Coastguard Worker                                   {convLayer, weightsLayer, biasLayer});
354*89c4ff92SAndroid Build Coastguard Worker }
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with five convolutions layers, all supported by the mock backend
BuildFullyOptimizableSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)357*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyOptimizableSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
358*89c4ff92SAndroid Build Coastguard Worker {
359*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
360*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
361*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
362*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo  ({  1,  1,  1, 16 }, DataType::Signed32,        0.9f, 0);
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
365*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
368*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
369*89c4ff92SAndroid Build Coastguard Worker 
370*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
371*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
372*89c4ff92SAndroid Build Coastguard Worker 
373*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
374*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
375*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
376*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
377*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
378*89c4ff92SAndroid Build Coastguard Worker 
379*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
380*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
381*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
382*89c4ff92SAndroid Build Coastguard Worker                                                                "conv1 layer", outputInfo);
383*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer1 =
384*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
385*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer1 =
386*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
389*89c4ff92SAndroid Build Coastguard Worker                                                                "conv2 layer", outputInfo);
390*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer2 =
391*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 2", constWeightsTensor, weightInfo);
392*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer2 =
393*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 2", constBiasTensor, biasInfo);
394*89c4ff92SAndroid Build Coastguard Worker 
395*89c4ff92SAndroid Build Coastguard Worker 
396*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
397*89c4ff92SAndroid Build Coastguard Worker                                                                "conv3 layer", outputInfo);
398*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer3 =
399*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
400*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer3 =
401*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
402*89c4ff92SAndroid Build Coastguard Worker 
403*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv4Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
404*89c4ff92SAndroid Build Coastguard Worker                                                                "conv4 layer", outputInfo);
405*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer4 =
406*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 4", constWeightsTensor, weightInfo);
407*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer4 =
408*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 4", constBiasTensor, biasInfo);
409*89c4ff92SAndroid Build Coastguard Worker 
410*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv5Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
411*89c4ff92SAndroid Build Coastguard Worker                                                                "conv5 layer", outputInfo);
412*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer5 =
413*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 5", constWeightsTensor, weightInfo);
414*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer5 =
415*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 5", constBiasTensor, biasInfo);
416*89c4ff92SAndroid Build Coastguard Worker 
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
419*89c4ff92SAndroid Build Coastguard Worker 
420*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
421*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
422*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
423*89c4ff92SAndroid Build Coastguard Worker     biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
424*89c4ff92SAndroid Build Coastguard Worker 
425*89c4ff92SAndroid Build Coastguard Worker     conv1Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
426*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
427*89c4ff92SAndroid Build Coastguard Worker     biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker     conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
430*89c4ff92SAndroid Build Coastguard Worker     weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
431*89c4ff92SAndroid Build Coastguard Worker     biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
432*89c4ff92SAndroid Build Coastguard Worker 
433*89c4ff92SAndroid Build Coastguard Worker     conv3Layer->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(0));
434*89c4ff92SAndroid Build Coastguard Worker     weightsLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(1));
435*89c4ff92SAndroid Build Coastguard Worker     biasLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(2));
436*89c4ff92SAndroid Build Coastguard Worker 
437*89c4ff92SAndroid Build Coastguard Worker     conv4Layer->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(0));
438*89c4ff92SAndroid Build Coastguard Worker     weightsLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(1));
439*89c4ff92SAndroid Build Coastguard Worker     biasLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(2));
440*89c4ff92SAndroid Build Coastguard Worker 
441*89c4ff92SAndroid Build Coastguard Worker     conv5Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
442*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
443*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
444*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
445*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({ conv5Layer }),
446*89c4ff92SAndroid Build Coastguard Worker                                   { weightsLayer1,
447*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer1,
448*89c4ff92SAndroid Build Coastguard Worker                                     conv1Layer,
449*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer2,
450*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer2,
451*89c4ff92SAndroid Build Coastguard Worker                                     conv2Layer,
452*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer3,
453*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer3,
454*89c4ff92SAndroid Build Coastguard Worker                                     conv3Layer,
455*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer4,
456*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer4,
457*89c4ff92SAndroid Build Coastguard Worker                                     conv4Layer,
458*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer5,
459*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer5,
460*89c4ff92SAndroid Build Coastguard Worker                                     conv5Layer });
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker 
463*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with both supported and unsupported layers
464*89c4ff92SAndroid Build Coastguard Worker // (only convolutions are unsupported by the mock backend)
BuildPartiallySupportedSubgraph(Graph & graph,LayerNameToLayerMap & layersInGraph)465*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallySupportedSubgraph(Graph& graph, LayerNameToLayerMap& layersInGraph)
466*89c4ff92SAndroid Build Coastguard Worker {
467*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
468*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
469*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
470*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo  ({  1,  1,  1, 16 }, DataType::Signed32,        0.9f, 0);
471*89c4ff92SAndroid Build Coastguard Worker 
472*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
473*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
474*89c4ff92SAndroid Build Coastguard Worker 
475*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
476*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
477*89c4ff92SAndroid Build Coastguard Worker 
478*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
479*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
480*89c4ff92SAndroid Build Coastguard Worker 
481*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
482*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
483*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
484*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
485*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
486*89c4ff92SAndroid Build Coastguard Worker 
487*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor poolingDescriptor;
488*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolType      = armnn::PoolingAlgorithm::Average;
489*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolWidth     = 2;
490*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PoolHeight    = 2;
491*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideX       = 2;
492*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_StrideY       = 2;
493*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadLeft       = 1;
494*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadRight      = 1;
495*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadTop        = 1;
496*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PadBottom     = 1;
497*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
498*89c4ff92SAndroid Build Coastguard Worker     poolingDescriptor.m_DataLayout    = DataLayout::NHWC;
499*89c4ff92SAndroid Build Coastguard Worker 
500*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
501*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
502*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer1 =
503*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
504*89c4ff92SAndroid Build Coastguard Worker 
505*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer1 =
506*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
507*89c4ff92SAndroid Build Coastguard Worker 
508*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
509*89c4ff92SAndroid Build Coastguard Worker                                                                "conv1 layer", outputInfo);
510*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling1Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
511*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling1 layer", outputInfo);
512*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling2Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
513*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling2 layer", outputInfo);
514*89c4ff92SAndroid Build Coastguard Worker 
515*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer2 =
516*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 2", constWeightsTensor, weightInfo);
517*89c4ff92SAndroid Build Coastguard Worker 
518*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer2 =
519*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 2", constBiasTensor, biasInfo);
520*89c4ff92SAndroid Build Coastguard Worker 
521*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
522*89c4ff92SAndroid Build Coastguard Worker                                                                "conv2 layer", outputInfo);
523*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* const pooling3Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
524*89c4ff92SAndroid Build Coastguard Worker                                                           "pooling3 layer", outputInfo);
525*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
528*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
529*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
530*89c4ff92SAndroid Build Coastguard Worker     biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
531*89c4ff92SAndroid Build Coastguard Worker     conv1Layer->GetOutputSlot(0).Connect(pooling1Layer->GetInputSlot(0));
532*89c4ff92SAndroid Build Coastguard Worker     pooling1Layer->GetOutputSlot(0).Connect(pooling2Layer->GetInputSlot(0));
533*89c4ff92SAndroid Build Coastguard Worker     pooling2Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
534*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
535*89c4ff92SAndroid Build Coastguard Worker     biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
536*89c4ff92SAndroid Build Coastguard Worker     conv2Layer->GetOutputSlot(0).Connect(pooling3Layer->GetInputSlot(0));
537*89c4ff92SAndroid Build Coastguard Worker     pooling3Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
538*89c4ff92SAndroid Build Coastguard Worker 
539*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
540*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
541*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
542*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({pooling3Layer}),
543*89c4ff92SAndroid Build Coastguard Worker                                   {weightsLayer1,
544*89c4ff92SAndroid Build Coastguard Worker                                    biasLayer1,
545*89c4ff92SAndroid Build Coastguard Worker                                    conv1Layer,
546*89c4ff92SAndroid Build Coastguard Worker                                    pooling1Layer,
547*89c4ff92SAndroid Build Coastguard Worker                                    pooling2Layer,
548*89c4ff92SAndroid Build Coastguard Worker                                    weightsLayer2,
549*89c4ff92SAndroid Build Coastguard Worker                                    biasLayer2,
550*89c4ff92SAndroid Build Coastguard Worker                                    conv2Layer,
551*89c4ff92SAndroid Build Coastguard Worker                                    pooling3Layer});
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker 
554*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with only unoptimizable layers ("unoptimizable" is added to the layer's name)
BuildFullyUnoptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)555*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnoptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
556*89c4ff92SAndroid Build Coastguard Worker {
557*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
558*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
559*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
560*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo  ({  1,  1,  1, 16 }, DataType::Signed32,        0.9f, 0);
561*89c4ff92SAndroid Build Coastguard Worker 
562*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
563*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
564*89c4ff92SAndroid Build Coastguard Worker 
565*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
566*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
567*89c4ff92SAndroid Build Coastguard Worker 
568*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
569*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
570*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
571*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
572*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
573*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
574*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
575*89c4ff92SAndroid Build Coastguard Worker 
576*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
577*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
578*89c4ff92SAndroid Build Coastguard Worker 
579*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer =
580*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer unoptimizable", constWeightsTensor, weightInfo);
581*89c4ff92SAndroid Build Coastguard Worker 
582*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer =
583*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer unoptimizable", constBiasTensor, biasInfo);
584*89c4ff92SAndroid Build Coastguard Worker 
585*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const convLayer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
586*89c4ff92SAndroid Build Coastguard Worker                                                                "conv layer unoptimizable", outputInfo);
587*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
588*89c4ff92SAndroid Build Coastguard Worker 
589*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
590*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
591*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
592*89c4ff92SAndroid Build Coastguard Worker     biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
593*89c4ff92SAndroid Build Coastguard Worker     convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
594*89c4ff92SAndroid Build Coastguard Worker 
595*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
596*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
597*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(convLayer, ignoreSlots),
598*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({convLayer}),
599*89c4ff92SAndroid Build Coastguard Worker                                   {convLayer, weightsLayer, biasLayer});
600*89c4ff92SAndroid Build Coastguard Worker }
601*89c4ff92SAndroid Build Coastguard Worker 
602*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with some unoptimizable layers ("unoptimizable" is added to the layer's name)
BuildPartiallyOptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)603*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallyOptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
604*89c4ff92SAndroid Build Coastguard Worker {
605*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
606*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
607*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
608*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo  ({  1,  1,  1, 16 }, DataType::Signed32,        0.9f, 0);
609*89c4ff92SAndroid Build Coastguard Worker 
610*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
611*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
612*89c4ff92SAndroid Build Coastguard Worker 
613*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
614*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
615*89c4ff92SAndroid Build Coastguard Worker 
616*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
617*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
618*89c4ff92SAndroid Build Coastguard Worker 
619*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
620*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
621*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
622*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
623*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
626*89c4ff92SAndroid Build Coastguard Worker     Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
627*89c4ff92SAndroid Build Coastguard Worker 
628*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer1 =
629*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
630*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer1 =
631*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
632*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer2 =
633*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 2 unoptimizable", constWeightsTensor, weightInfo);
634*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer2 =
635*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 2 unoptimizable", constBiasTensor, biasInfo);
636*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer3 =
637*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
638*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer3 =
639*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
640*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer4 =
641*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 4 unoptimizable", constWeightsTensor, weightInfo);
642*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer4 =
643*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 4 unoptimizable", constBiasTensor, biasInfo);
644*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer5 =
645*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 5", constWeightsTensor, weightInfo);
646*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer5 =
647*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 5", constBiasTensor, biasInfo);
648*89c4ff92SAndroid Build Coastguard Worker 
649*89c4ff92SAndroid Build Coastguard Worker   Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
650*89c4ff92SAndroid Build Coastguard Worker                                                              "conv1 layer", outputInfo);
651*89c4ff92SAndroid Build Coastguard Worker   Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
652*89c4ff92SAndroid Build Coastguard Worker                                                              "conv2 layer unoptimizable", outputInfo);
653*89c4ff92SAndroid Build Coastguard Worker   Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
654*89c4ff92SAndroid Build Coastguard Worker                                                              "conv3 layer", outputInfo);
655*89c4ff92SAndroid Build Coastguard Worker   Convolution2dLayer* const conv4Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
656*89c4ff92SAndroid Build Coastguard Worker                                                              "conv4 layer unoptimizable", outputInfo);
657*89c4ff92SAndroid Build Coastguard Worker   Convolution2dLayer* const conv5Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
658*89c4ff92SAndroid Build Coastguard Worker                                                              "conv5 layer", outputInfo);
659*89c4ff92SAndroid Build Coastguard Worker 
660*89c4ff92SAndroid Build Coastguard Worker   Layer* const outputLayer = AddOutputLayer(graph, "output layer");
661*89c4ff92SAndroid Build Coastguard Worker 
662*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
663*89c4ff92SAndroid Build Coastguard Worker     inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
664*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
665*89c4ff92SAndroid Build Coastguard Worker     biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
666*89c4ff92SAndroid Build Coastguard Worker 
667*89c4ff92SAndroid Build Coastguard Worker     conv1Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
668*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
669*89c4ff92SAndroid Build Coastguard Worker     biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
670*89c4ff92SAndroid Build Coastguard Worker 
671*89c4ff92SAndroid Build Coastguard Worker     conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
672*89c4ff92SAndroid Build Coastguard Worker     weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
673*89c4ff92SAndroid Build Coastguard Worker     biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
674*89c4ff92SAndroid Build Coastguard Worker 
675*89c4ff92SAndroid Build Coastguard Worker     conv3Layer->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(0));
676*89c4ff92SAndroid Build Coastguard Worker     weightsLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(1));
677*89c4ff92SAndroid Build Coastguard Worker     biasLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(2));
678*89c4ff92SAndroid Build Coastguard Worker 
679*89c4ff92SAndroid Build Coastguard Worker     conv4Layer->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(0));
680*89c4ff92SAndroid Build Coastguard Worker     weightsLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(1));
681*89c4ff92SAndroid Build Coastguard Worker     biasLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(2));
682*89c4ff92SAndroid Build Coastguard Worker 
683*89c4ff92SAndroid Build Coastguard Worker     conv5Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
684*89c4ff92SAndroid Build Coastguard Worker 
685*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
686*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
687*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
688*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({conv5Layer}),
689*89c4ff92SAndroid Build Coastguard Worker                                   {weightsLayer1,
690*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer1,
691*89c4ff92SAndroid Build Coastguard Worker                                     conv1Layer,
692*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer2,
693*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer2,
694*89c4ff92SAndroid Build Coastguard Worker                                     conv2Layer,
695*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer3,
696*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer3,
697*89c4ff92SAndroid Build Coastguard Worker                                     conv3Layer,
698*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer4,
699*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer4,
700*89c4ff92SAndroid Build Coastguard Worker                                     conv4Layer,
701*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer5,
702*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer5,
703*89c4ff92SAndroid Build Coastguard Worker                                     conv5Layer});
704*89c4ff92SAndroid Build Coastguard Worker }
705*89c4ff92SAndroid Build Coastguard Worker 
706*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with some input unoptimizable layers ("unoptimizable" is added to the layer's name),
707*89c4ff92SAndroid Build Coastguard Worker // this is meant to test input slots coming from different layers
BuildPartiallyOptimizableSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)708*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallyOptimizableSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
709*89c4ff92SAndroid Build Coastguard Worker {
710*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo inputInfo ({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
711*89c4ff92SAndroid Build Coastguard Worker     const TensorInfo outputInfo({  1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
712*89c4ff92SAndroid Build Coastguard Worker     TensorInfo weightInfo({ 16,  1,  1, 16 }, DataType::QAsymmU8, 0.9f, 0);
713*89c4ff92SAndroid Build Coastguard Worker     TensorInfo biasInfo  ({  1,  1,  1, 16 }, DataType::Signed32,        0.9f, 0);
714*89c4ff92SAndroid Build Coastguard Worker 
715*89c4ff92SAndroid Build Coastguard Worker     weightInfo.SetConstant(true);
716*89c4ff92SAndroid Build Coastguard Worker     biasInfo.SetConstant(true);
717*89c4ff92SAndroid Build Coastguard Worker 
718*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(64);
719*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constWeightsTensor(weightInfo, weightsVector);
720*89c4ff92SAndroid Build Coastguard Worker 
721*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasVector(16);
722*89c4ff92SAndroid Build Coastguard Worker     ConstTensor constBiasTensor(biasInfo, biasVector);
723*89c4ff92SAndroid Build Coastguard Worker 
724*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolutionDescriptor;
725*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideX     = 1;
726*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_StrideY     = 1;
727*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_BiasEnabled = true;
728*89c4ff92SAndroid Build Coastguard Worker     convolutionDescriptor.m_DataLayout  = DataLayout::NHWC;
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     // Construct the graph
731*89c4ff92SAndroid Build Coastguard Worker     Layer* const input1Layer = AddInputLayer(graph, "input1 layer", inputInfo, 0);
732*89c4ff92SAndroid Build Coastguard Worker     Layer* const input2Layer = AddInputLayer(graph, "input2 layer", inputInfo, 1);
733*89c4ff92SAndroid Build Coastguard Worker 
734*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer1 =
735*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
736*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer1 =
737*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
738*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer2 =
739*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 2 unoptimizable", constWeightsTensor, weightInfo);
740*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer2 =
741*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 2 unoptimizable", constBiasTensor, biasInfo);
742*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const weightsLayer3 =
743*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
744*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* const biasLayer3 =
745*89c4ff92SAndroid Build Coastguard Worker         AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
746*89c4ff92SAndroid Build Coastguard Worker 
747*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
748*89c4ff92SAndroid Build Coastguard Worker                                                                "conv1 layer", outputInfo);
749*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
750*89c4ff92SAndroid Build Coastguard Worker                                                                "conv2 layer unoptimizable", outputInfo);
751*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
752*89c4ff92SAndroid Build Coastguard Worker                                                                "conv3 layer", outputInfo);
753*89c4ff92SAndroid Build Coastguard Worker     AdditionLayer* const addLayer = AddAdditionaLayer(graph, layersInGraph, "add layer", outputInfo);
754*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputLayer = AddOutputLayer(graph, "output layer");
755*89c4ff92SAndroid Build Coastguard Worker 
756*89c4ff92SAndroid Build Coastguard Worker     // Connect the network
757*89c4ff92SAndroid Build Coastguard Worker     input1Layer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
758*89c4ff92SAndroid Build Coastguard Worker     weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
759*89c4ff92SAndroid Build Coastguard Worker     biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
760*89c4ff92SAndroid Build Coastguard Worker     conv1Layer->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0));
761*89c4ff92SAndroid Build Coastguard Worker 
762*89c4ff92SAndroid Build Coastguard Worker     input2Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
763*89c4ff92SAndroid Build Coastguard Worker     weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
764*89c4ff92SAndroid Build Coastguard Worker     biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
765*89c4ff92SAndroid Build Coastguard Worker     conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
766*89c4ff92SAndroid Build Coastguard Worker     weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
767*89c4ff92SAndroid Build Coastguard Worker     biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
768*89c4ff92SAndroid Build Coastguard Worker     conv3Layer->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1));
769*89c4ff92SAndroid Build Coastguard Worker 
770*89c4ff92SAndroid Build Coastguard Worker     addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
771*89c4ff92SAndroid Build Coastguard Worker 
772*89c4ff92SAndroid Build Coastguard Worker     // Create the subgraph view for the whole network
773*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> ignoreSlots = {1, 2};
774*89c4ff92SAndroid Build Coastguard Worker     return CreateSubgraphViewFrom(CreateInputsFrom({conv1Layer,
775*89c4ff92SAndroid Build Coastguard Worker                                                     conv2Layer}, ignoreSlots),
776*89c4ff92SAndroid Build Coastguard Worker                                   CreateOutputsFrom({addLayer}),
777*89c4ff92SAndroid Build Coastguard Worker                                   { weightsLayer1,
778*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer1,
779*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer2,
780*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer2,
781*89c4ff92SAndroid Build Coastguard Worker                                     weightsLayer3,
782*89c4ff92SAndroid Build Coastguard Worker                                     biasLayer3,
783*89c4ff92SAndroid Build Coastguard Worker                                     conv1Layer,
784*89c4ff92SAndroid Build Coastguard Worker                                     conv2Layer,
785*89c4ff92SAndroid Build Coastguard Worker                                     conv3Layer,
786*89c4ff92SAndroid Build Coastguard Worker                                     addLayer });
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker 
789*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only a single unsupported layer (only convolutions are unsupported by the mock backend)
FullyUnsupporteSubgraphTestImpl1()790*89c4ff92SAndroid Build Coastguard Worker void FullyUnsupporteSubgraphTestImpl1()
791*89c4ff92SAndroid Build Coastguard Worker {
792*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
793*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
794*89c4ff92SAndroid Build Coastguard Worker 
795*89c4ff92SAndroid Build Coastguard Worker     // Create an unsupported subgraph
796*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnsupportedSubgraph1(graph, layersInGraph);
797*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
798*89c4ff92SAndroid Build Coastguard Worker 
799*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
800*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
801*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
802*89c4ff92SAndroid Build Coastguard Worker 
803*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
804*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
805*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 1);
806*89c4ff92SAndroid Build Coastguard Worker 
807*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling layer"));
808*89c4ff92SAndroid Build Coastguard Worker 
809*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
810*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
811*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
812*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
813*89c4ff92SAndroid Build Coastguard Worker 
814*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
815*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
816*89c4ff92SAndroid Build Coastguard Worker 
817*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly, but no optimization is performed
818*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
819*89c4ff92SAndroid Build Coastguard Worker 
820*89c4ff92SAndroid Build Coastguard Worker     // =======================================================================
821*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
822*89c4ff92SAndroid Build Coastguard Worker     //  - No substitutions
823*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one failed subgraph, corresponding to the whole original one
824*89c4ff92SAndroid Build Coastguard Worker     //  - No untouched subgraphs
825*89c4ff92SAndroid Build Coastguard Worker     // =======================================================================
826*89c4ff92SAndroid Build Coastguard Worker 
827*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
828*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
829*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
830*89c4ff92SAndroid Build Coastguard Worker 
831*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetSubstitutions().empty());
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
834*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
835*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
836*89c4ff92SAndroid Build Coastguard Worker 
837*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs();
838*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphs.size() == 1);
839*89c4ff92SAndroid Build Coastguard Worker 
840*89c4ff92SAndroid Build Coastguard Worker     CheckFailedSubgraph(failedSubgraphs.at(0),
841*89c4ff92SAndroid Build Coastguard Worker                         { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
842*89c4ff92SAndroid Build Coastguard Worker                         subgraphInputSlots,
843*89c4ff92SAndroid Build Coastguard Worker                         subgraphOutputSlots,
844*89c4ff92SAndroid Build Coastguard Worker                         subgraphLayers);
845*89c4ff92SAndroid Build Coastguard Worker 
846*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
847*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
848*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
849*89c4ff92SAndroid Build Coastguard Worker 
850*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
851*89c4ff92SAndroid Build Coastguard Worker }
852*89c4ff92SAndroid Build Coastguard Worker 
853*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only unsupported layers (only convolutions are unsupported by the mock backend)
FullyUnsupporteSubgraphTestImpl2()854*89c4ff92SAndroid Build Coastguard Worker void FullyUnsupporteSubgraphTestImpl2()
855*89c4ff92SAndroid Build Coastguard Worker {
856*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
857*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
858*89c4ff92SAndroid Build Coastguard Worker 
859*89c4ff92SAndroid Build Coastguard Worker     // Create an unsupported subgraph
860*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnsupportedSubgraph2(graph, layersInGraph);
861*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
862*89c4ff92SAndroid Build Coastguard Worker 
863*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
864*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
865*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
866*89c4ff92SAndroid Build Coastguard Worker 
867*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
868*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
869*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 3);
870*89c4ff92SAndroid Build Coastguard Worker 
871*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling1 layer"));
872*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling2 layer"));
873*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling3 layer"));
874*89c4ff92SAndroid Build Coastguard Worker 
875*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
876*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
877*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
878*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
879*89c4ff92SAndroid Build Coastguard Worker 
880*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
881*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
882*89c4ff92SAndroid Build Coastguard Worker 
883*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly, but no optimization is performed
884*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
885*89c4ff92SAndroid Build Coastguard Worker 
886*89c4ff92SAndroid Build Coastguard Worker     // =======================================================================
887*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
888*89c4ff92SAndroid Build Coastguard Worker     //  - No substitutions
889*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one failed subgraph, corresponding to the whole original one
890*89c4ff92SAndroid Build Coastguard Worker     //  - No untouched subgraphs
891*89c4ff92SAndroid Build Coastguard Worker     // =======================================================================
892*89c4ff92SAndroid Build Coastguard Worker 
893*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
894*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
895*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
896*89c4ff92SAndroid Build Coastguard Worker 
897*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetSubstitutions().empty());
898*89c4ff92SAndroid Build Coastguard Worker 
899*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
900*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
901*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
902*89c4ff92SAndroid Build Coastguard Worker 
903*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs();
904*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphs.size() == 1);
905*89c4ff92SAndroid Build Coastguard Worker 
906*89c4ff92SAndroid Build Coastguard Worker     std::list<IConnectableLayer*> expectedFailedLayers{ layersInGraph.at("pooling1 layer"),
907*89c4ff92SAndroid Build Coastguard Worker                                             layersInGraph.at("pooling2 layer"),
908*89c4ff92SAndroid Build Coastguard Worker                                             layersInGraph.at("pooling3 layer") };
909*89c4ff92SAndroid Build Coastguard Worker 
910*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView& failedSubgraph = failedSubgraphs.at(0);
911*89c4ff92SAndroid Build Coastguard Worker 
912*89c4ff92SAndroid Build Coastguard Worker     CheckFailedSubgraph(failedSubgraph,
913*89c4ff92SAndroid Build Coastguard Worker                         { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
914*89c4ff92SAndroid Build Coastguard Worker                         subgraphInputSlots,
915*89c4ff92SAndroid Build Coastguard Worker                         subgraphOutputSlots,
916*89c4ff92SAndroid Build Coastguard Worker                         subgraphLayers);
917*89c4ff92SAndroid Build Coastguard Worker 
918*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& failedSubgraphLayers = failedSubgraph.GetIConnectableLayers();
919*89c4ff92SAndroid Build Coastguard Worker 
920*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(failedSubgraphLayers.front() + 0, expectedFailedLayers.front() + 0);
921*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(failedSubgraphLayers.front() + 1, expectedFailedLayers.front() + 1);
922*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(failedSubgraphLayers.front() + 2, expectedFailedLayers.front() + 2);
923*89c4ff92SAndroid Build Coastguard Worker 
924*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
925*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
926*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
927*89c4ff92SAndroid Build Coastguard Worker 
928*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
929*89c4ff92SAndroid Build Coastguard Worker }
930*89c4ff92SAndroid Build Coastguard Worker 
931*89c4ff92SAndroid Build Coastguard Worker // A simple case with only one layer (convolution) to optimize, supported by the mock backend
FullyOptimizableSubgraphTestImpl1()932*89c4ff92SAndroid Build Coastguard Worker void FullyOptimizableSubgraphTestImpl1()
933*89c4ff92SAndroid Build Coastguard Worker {
934*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
935*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
936*89c4ff92SAndroid Build Coastguard Worker 
937*89c4ff92SAndroid Build Coastguard Worker     // Create a fully optimizable subgraph
938*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyOptimizableSubgraph1(graph, layersInGraph);
939*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
940*89c4ff92SAndroid Build Coastguard Worker 
941*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
942*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
943*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
944*89c4ff92SAndroid Build Coastguard Worker 
945*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
946*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
947*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 3);
948*89c4ff92SAndroid Build Coastguard Worker 
949*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv layer"));
950*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer"));
951*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer"));
952*89c4ff92SAndroid Build Coastguard Worker 
953*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
954*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
955*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
956*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
957*89c4ff92SAndroid Build Coastguard Worker 
958*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
959*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
960*89c4ff92SAndroid Build Coastguard Worker 
961*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
962*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
963*89c4ff92SAndroid Build Coastguard Worker 
964*89c4ff92SAndroid Build Coastguard Worker     // ===========================================================================================
965*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
966*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one substitution, mapping the whole input subgraph to a new replacement subgraph
967*89c4ff92SAndroid Build Coastguard Worker     //  - No failed subgraphs
968*89c4ff92SAndroid Build Coastguard Worker     //  - No untouched subgraphs
969*89c4ff92SAndroid Build Coastguard Worker     // ===========================================================================================
970*89c4ff92SAndroid Build Coastguard Worker 
971*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
972*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
973*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
974*89c4ff92SAndroid Build Coastguard Worker 
975*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
976*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutions.size() == 1);
977*89c4ff92SAndroid Build Coastguard Worker 
978*89c4ff92SAndroid Build Coastguard Worker     CheckSubstitution(substitutions.at(0),
979*89c4ff92SAndroid Build Coastguard Worker                       { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
980*89c4ff92SAndroid Build Coastguard Worker                       { subgraphInputSlots.size(), subgraphOutputSlots.size(), 1 },
981*89c4ff92SAndroid Build Coastguard Worker                       subgraphInputSlots,
982*89c4ff92SAndroid Build Coastguard Worker                       subgraphOutputSlots,
983*89c4ff92SAndroid Build Coastguard Worker                       subgraphLayers);
984*89c4ff92SAndroid Build Coastguard Worker 
985*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
986*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
987*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
988*89c4ff92SAndroid Build Coastguard Worker 
989*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetFailedSubgraphs().empty());
990*89c4ff92SAndroid Build Coastguard Worker 
991*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
992*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
993*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
994*89c4ff92SAndroid Build Coastguard Worker 
995*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
996*89c4ff92SAndroid Build Coastguard Worker }
997*89c4ff92SAndroid Build Coastguard Worker 
998*89c4ff92SAndroid Build Coastguard Worker // A case with five layers (all convolutions) to optimize, all supported by the mock backend
FullyOptimizableSubgraphTestImpl2()999*89c4ff92SAndroid Build Coastguard Worker void FullyOptimizableSubgraphTestImpl2()
1000*89c4ff92SAndroid Build Coastguard Worker {
1001*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
1002*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
1003*89c4ff92SAndroid Build Coastguard Worker 
1004*89c4ff92SAndroid Build Coastguard Worker     // Create a fully optimizable subgraph
1005*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyOptimizableSubgraph2(graph, layersInGraph);
1006*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
1007*89c4ff92SAndroid Build Coastguard Worker 
1008*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1009*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1010*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1011*89c4ff92SAndroid Build Coastguard Worker 
1012*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
1013*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
1014*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphPtr->GetIConnectableLayers().size() == 15);
1015*89c4ff92SAndroid Build Coastguard Worker 
1016*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv1 layer"));
1017*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv2 layer"));
1018*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv3 layer"));
1019*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv4 layer"));
1020*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv5 layer"));
1021*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 1"));
1022*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 2"));
1023*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 3"));
1024*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 4"));
1025*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 5"));
1026*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 1"));
1027*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 2"));
1028*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 3"));
1029*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 4"));
1030*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 5"));
1031*89c4ff92SAndroid Build Coastguard Worker 
1032*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
1033*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
1034*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
1035*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
1036*89c4ff92SAndroid Build Coastguard Worker 
1037*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
1038*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
1039*89c4ff92SAndroid Build Coastguard Worker 
1040*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
1041*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1042*89c4ff92SAndroid Build Coastguard Worker 
1043*89c4ff92SAndroid Build Coastguard Worker     // ===========================================================================================
1044*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
1045*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one substitution, mapping the whole input subgraph to a new replacement subgraph
1046*89c4ff92SAndroid Build Coastguard Worker     //  - No failed subgraphs
1047*89c4ff92SAndroid Build Coastguard Worker     //  - No untouched subgraphs
1048*89c4ff92SAndroid Build Coastguard Worker     // ===========================================================================================
1049*89c4ff92SAndroid Build Coastguard Worker 
1050*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1051*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
1052*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1053*89c4ff92SAndroid Build Coastguard Worker 
1054*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
1055*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutions.size() == 1);
1056*89c4ff92SAndroid Build Coastguard Worker 
1057*89c4ff92SAndroid Build Coastguard Worker     std::list<IConnectableLayer*> expectedSubstitutableLayers{
1058*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Weights Layer 1"),
1059*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Weights Layer 2"),
1060*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Weights Layer 3"),
1061*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Weights Layer 4"),
1062*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Weights Layer 5"),
1063*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Bias Layer 1"),
1064*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Bias Layer 2"),
1065*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Bias Layer 3"),
1066*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Bias Layer 4"),
1067*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("Bias Layer 5"),
1068*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("conv1 layer"),
1069*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("conv2 layer"),
1070*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("conv3 layer"),
1071*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("conv4 layer"),
1072*89c4ff92SAndroid Build Coastguard Worker                                                    layersInGraph.at("conv5 layer")};
1073*89c4ff92SAndroid Build Coastguard Worker 
1074*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::SubstitutionPair& substitution = substitutions.at(0);
1075*89c4ff92SAndroid Build Coastguard Worker 
1076*89c4ff92SAndroid Build Coastguard Worker     CheckSubstitution(
1077*89c4ff92SAndroid Build Coastguard Worker         substitution,
1078*89c4ff92SAndroid Build Coastguard Worker         {subgraphInputSlots.size(), subgraphOutputSlots.size(),
1079*89c4ff92SAndroid Build Coastguard Worker          subgraphLayers.size()},
1080*89c4ff92SAndroid Build Coastguard Worker         {subgraphInputSlots.size(), subgraphOutputSlots.size(), 1},
1081*89c4ff92SAndroid Build Coastguard Worker         subgraphInputSlots, subgraphOutputSlots, expectedSubstitutableLayers);
1082*89c4ff92SAndroid Build Coastguard Worker 
1083*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& substitutableSubgraphLayers =
1084*89c4ff92SAndroid Build Coastguard Worker             substitution.m_SubstitutableSubgraph.GetIConnectableLayers();
1085*89c4ff92SAndroid Build Coastguard Worker 
1086*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(substitutableSubgraphLayers.front() + 0, expectedSubstitutableLayers.front() + 0);
1087*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(substitutableSubgraphLayers.front() + 1, expectedSubstitutableLayers.front() + 1);
1088*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(substitutableSubgraphLayers.front() + 2, expectedSubstitutableLayers.front() + 2);
1089*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(substitutableSubgraphLayers.front() + 3, expectedSubstitutableLayers.front() + 3);
1090*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(substitutableSubgraphLayers.front() + 4, expectedSubstitutableLayers.front() + 4);
1091*89c4ff92SAndroid Build Coastguard Worker 
1092*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1093*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
1094*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1095*89c4ff92SAndroid Build Coastguard Worker 
1096*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetFailedSubgraphs().empty());
1097*89c4ff92SAndroid Build Coastguard Worker 
1098*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1099*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
1100*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1101*89c4ff92SAndroid Build Coastguard Worker 
1102*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
1103*89c4ff92SAndroid Build Coastguard Worker }
1104*89c4ff92SAndroid Build Coastguard Worker 
1105*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contaions both supported and unsupported layers
1106*89c4ff92SAndroid Build Coastguard Worker // (but only convolutions are unsupported by the mock backend)
PartiallySupportedSubgraphTestImpl()1107*89c4ff92SAndroid Build Coastguard Worker void PartiallySupportedSubgraphTestImpl()
1108*89c4ff92SAndroid Build Coastguard Worker {
1109*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
1110*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
1111*89c4ff92SAndroid Build Coastguard Worker 
1112*89c4ff92SAndroid Build Coastguard Worker     // Create a fully optimizable subgraph
1113*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallySupportedSubgraph(graph, layersInGraph);
1114*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
1115*89c4ff92SAndroid Build Coastguard Worker 
1116*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1117*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1118*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1119*89c4ff92SAndroid Build Coastguard Worker 
1120*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
1121*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
1122*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 9);
1123*89c4ff92SAndroid Build Coastguard Worker 
1124*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 1"));
1125*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 1"));
1126*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv1 layer"));
1127*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling1 layer"));
1128*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling2 layer"));
1129*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Weights Layer 2"));
1130*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "Bias Layer 2"));
1131*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv2 layer"));
1132*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "pooling3 layer"));
1133*89c4ff92SAndroid Build Coastguard Worker 
1134*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
1135*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
1136*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
1137*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
1138*89c4ff92SAndroid Build Coastguard Worker 
1139*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
1140*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
1141*89c4ff92SAndroid Build Coastguard Worker 
1142*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
1143*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1144*89c4ff92SAndroid Build Coastguard Worker 
1145*89c4ff92SAndroid Build Coastguard Worker     // ========================================================================
1146*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
1147*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly two substitution, corresponding to the supported layers
1148*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly two failed subgraphs, corresponding to the unsupported layers
1149*89c4ff92SAndroid Build Coastguard Worker     //  - No untouched subgraphs
1150*89c4ff92SAndroid Build Coastguard Worker     // ========================================================================
1151*89c4ff92SAndroid Build Coastguard Worker 
1152*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1153*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
1154*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1155*89c4ff92SAndroid Build Coastguard Worker 
1156*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
1157*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutions.size() == 2);
1158*89c4ff92SAndroid Build Coastguard Worker     // Sort into a consistent order
1159*89c4ff92SAndroid Build Coastguard Worker     std::sort(substitutions.begin(), substitutions.end(), [](auto s1, auto s2) {
1160*89c4ff92SAndroid Build Coastguard Worker         return strcmp(s1.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName(),
1161*89c4ff92SAndroid Build Coastguard Worker                       s2.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName()) < 0;
1162*89c4ff92SAndroid Build Coastguard Worker     });
1163*89c4ff92SAndroid Build Coastguard Worker 
1164*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 3 },
1165*89c4ff92SAndroid Build Coastguard Worker                                                                           { 1, 1, 3 } };
1166*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 },
1167*89c4ff92SAndroid Build Coastguard Worker                                                                         { 1, 1, 1 } };
1168*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IInputSlots> expectedSubstitutableInputSlots
1169*89c4ff92SAndroid Build Coastguard Worker     {
1170*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<InputSlot, IInputSlot>(
1171*89c4ff92SAndroid Build Coastguard Worker                 {ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlot(0))}),
1172*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<InputSlot, IInputSlot>(
1173*89c4ff92SAndroid Build Coastguard Worker                 {ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer")->GetInputSlot(0))})
1174*89c4ff92SAndroid Build Coastguard Worker     };
1175*89c4ff92SAndroid Build Coastguard Worker 
1176*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IOutputSlots> expectedSubstitutableOutputSlots
1177*89c4ff92SAndroid Build Coastguard Worker     {
1178*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1179*89c4ff92SAndroid Build Coastguard Worker                 ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots())),
1180*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1181*89c4ff92SAndroid Build Coastguard Worker                 ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer")->GetOutputSlots()))
1182*89c4ff92SAndroid Build Coastguard Worker     };
1183*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IConnectableLayers> expectedSubstitutableLayers
1184*89c4ff92SAndroid Build Coastguard Worker     {
1185*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("Weights Layer 1"), layersInGraph.at("Bias Layer 1"), layersInGraph.at("conv1 layer") },
1186*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("Weights Layer 2"), layersInGraph.at("Bias Layer 2"), layersInGraph.at("conv2 layer") }
1187*89c4ff92SAndroid Build Coastguard Worker     };
1188*89c4ff92SAndroid Build Coastguard Worker 
1189*89c4ff92SAndroid Build Coastguard Worker     for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++)
1190*89c4ff92SAndroid Build Coastguard Worker     {
1191*89c4ff92SAndroid Build Coastguard Worker         CheckSubstitution(substitutions.at(substitutionIndex),
1192*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableSubgraphSizes.at(substitutionIndex),
1193*89c4ff92SAndroid Build Coastguard Worker                           expectedReplacementSubgraphSizes.at(substitutionIndex),
1194*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableInputSlots.at(substitutionIndex),
1195*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableOutputSlots.at(substitutionIndex),
1196*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableLayers.at(substitutionIndex));
1197*89c4ff92SAndroid Build Coastguard Worker     }
1198*89c4ff92SAndroid Build Coastguard Worker 
1199*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1200*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
1201*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1202*89c4ff92SAndroid Build Coastguard Worker 
1203*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews::Subgraphs failedSubgraphs = optimizationViews.GetFailedSubgraphs();
1204*89c4ff92SAndroid Build Coastguard Worker     CHECK(failedSubgraphs.size() == 2);
1205*89c4ff92SAndroid Build Coastguard Worker     // Sort into a consistent order
1206*89c4ff92SAndroid Build Coastguard Worker     std::sort(failedSubgraphs.begin(), failedSubgraphs.end(), [](auto s1, auto s2) {
1207*89c4ff92SAndroid Build Coastguard Worker         return strcmp(s1.GetIConnectableLayers().front()->GetName(),
1208*89c4ff92SAndroid Build Coastguard Worker                       s2.GetIConnectableLayers().front()->GetName()) < 0;
1209*89c4ff92SAndroid Build Coastguard Worker     });
1210*89c4ff92SAndroid Build Coastguard Worker 
1211*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedFailedSubgraphSizes{ { 1, 1, 2 },
1212*89c4ff92SAndroid Build Coastguard Worker                                                                    { 1, 1, 1 } };
1213*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IInputSlots> expectedFailedInputSlots
1214*89c4ff92SAndroid Build Coastguard Worker     {
1215*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot, IInputSlot>(
1216*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("pooling1 layer")->GetInputSlots())),
1217*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot, IInputSlot>(
1218*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("pooling3 layer")->GetInputSlots()))
1219*89c4ff92SAndroid Build Coastguard Worker     };
1220*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IOutputSlots> expectedFailedOutputSlots
1221*89c4ff92SAndroid Build Coastguard Worker     {
1222*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1223*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("pooling2 layer")->GetOutputSlots())),
1224*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1225*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("pooling3 layer")->GetOutputSlots()))
1226*89c4ff92SAndroid Build Coastguard Worker     };
1227*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IConnectableLayers> expectedFailedLayers
1228*89c4ff92SAndroid Build Coastguard Worker     {
1229*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("pooling1 layer"),
1230*89c4ff92SAndroid Build Coastguard Worker           layersInGraph.at("pooling2 layer") },
1231*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("pooling3 layer") }
1232*89c4ff92SAndroid Build Coastguard Worker     };
1233*89c4ff92SAndroid Build Coastguard Worker 
1234*89c4ff92SAndroid Build Coastguard Worker     for (size_t failedIndex = 0; failedIndex < failedSubgraphs.size(); failedIndex++)
1235*89c4ff92SAndroid Build Coastguard Worker     {
1236*89c4ff92SAndroid Build Coastguard Worker         CheckFailedSubgraph(failedSubgraphs.at(failedIndex),
1237*89c4ff92SAndroid Build Coastguard Worker                             expectedFailedSubgraphSizes.at(failedIndex),
1238*89c4ff92SAndroid Build Coastguard Worker                             expectedFailedInputSlots.at(failedIndex),
1239*89c4ff92SAndroid Build Coastguard Worker                             expectedFailedOutputSlots.at(failedIndex),
1240*89c4ff92SAndroid Build Coastguard Worker                             expectedFailedLayers.at(failedIndex));
1241*89c4ff92SAndroid Build Coastguard Worker     }
1242*89c4ff92SAndroid Build Coastguard Worker 
1243*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1244*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
1245*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1246*89c4ff92SAndroid Build Coastguard Worker 
1247*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
1248*89c4ff92SAndroid Build Coastguard Worker }
1249*89c4ff92SAndroid Build Coastguard Worker 
1250*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only unoptimizable layers ("unoptimizable" is added to the layer's name)
FullyUnoptimizableSubgraphTestImpl1()1251*89c4ff92SAndroid Build Coastguard Worker void FullyUnoptimizableSubgraphTestImpl1()
1252*89c4ff92SAndroid Build Coastguard Worker {
1253*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
1254*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
1255*89c4ff92SAndroid Build Coastguard Worker 
1256*89c4ff92SAndroid Build Coastguard Worker     // Create a fully optimizable subgraph
1257*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnoptimizableSubgraph1(graph, layersInGraph);
1258*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
1259*89c4ff92SAndroid Build Coastguard Worker 
1260*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1261*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1262*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1263*89c4ff92SAndroid Build Coastguard Worker 
1264*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
1265*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
1266*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 3);
1267*89c4ff92SAndroid Build Coastguard Worker 
1268*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv layer unoptimizable"));
1269*89c4ff92SAndroid Build Coastguard Worker 
1270*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
1271*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
1272*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
1273*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
1274*89c4ff92SAndroid Build Coastguard Worker 
1275*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
1276*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
1277*89c4ff92SAndroid Build Coastguard Worker 
1278*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
1279*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1280*89c4ff92SAndroid Build Coastguard Worker 
1281*89c4ff92SAndroid Build Coastguard Worker     // ============================================================================
1282*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
1283*89c4ff92SAndroid Build Coastguard Worker     //  - No substitutions
1284*89c4ff92SAndroid Build Coastguard Worker     //  - No failed subgraphs
1285*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one untouched subgraph, corresponding to the whole input subgraph
1286*89c4ff92SAndroid Build Coastguard Worker     // ============================================================================
1287*89c4ff92SAndroid Build Coastguard Worker 
1288*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1289*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
1290*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1291*89c4ff92SAndroid Build Coastguard Worker 
1292*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetSubstitutions().empty());
1293*89c4ff92SAndroid Build Coastguard Worker 
1294*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1295*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
1296*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1297*89c4ff92SAndroid Build Coastguard Worker 
1298*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetFailedSubgraphs().empty());
1299*89c4ff92SAndroid Build Coastguard Worker 
1300*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1301*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
1302*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1303*89c4ff92SAndroid Build Coastguard Worker 
1304*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1305*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphs.size() == 1);
1306*89c4ff92SAndroid Build Coastguard Worker 
1307*89c4ff92SAndroid Build Coastguard Worker     CheckUntouchedSubgraph(untouchedSubgraphs.at(0),
1308*89c4ff92SAndroid Build Coastguard Worker                            {subgraphInputSlots.size(),
1309*89c4ff92SAndroid Build Coastguard Worker                             subgraphOutputSlots.size(), subgraphLayers.size()},
1310*89c4ff92SAndroid Build Coastguard Worker                            subgraphInputSlots, subgraphOutputSlots,
1311*89c4ff92SAndroid Build Coastguard Worker                            subgraphLayers);
1312*89c4ff92SAndroid Build Coastguard Worker }
1313*89c4ff92SAndroid Build Coastguard Worker 
1314*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains some unoptimizable layers ("unoptimizable" is added to the layer's name)
PartiallyOptimizableSubgraphTestImpl1()1315*89c4ff92SAndroid Build Coastguard Worker void PartiallyOptimizableSubgraphTestImpl1()
1316*89c4ff92SAndroid Build Coastguard Worker {
1317*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
1318*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
1319*89c4ff92SAndroid Build Coastguard Worker 
1320*89c4ff92SAndroid Build Coastguard Worker     // Create a fully optimizable subgraph
1321*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph1(graph, layersInGraph);
1322*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
1323*89c4ff92SAndroid Build Coastguard Worker 
1324*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1325*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1326*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1327*89c4ff92SAndroid Build Coastguard Worker 
1328*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 1);
1329*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
1330*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 15);
1331*89c4ff92SAndroid Build Coastguard Worker 
1332*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv1 layer"));
1333*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv2 layer unoptimizable"));
1334*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv3 layer"));
1335*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv4 layer unoptimizable"));
1336*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv5 layer"));
1337*89c4ff92SAndroid Build Coastguard Worker 
1338*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
1339*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
1340*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
1341*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
1342*89c4ff92SAndroid Build Coastguard Worker 
1343*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
1344*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
1345*89c4ff92SAndroid Build Coastguard Worker 
1346*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
1347*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1348*89c4ff92SAndroid Build Coastguard Worker 
1349*89c4ff92SAndroid Build Coastguard Worker     // ===============================================================================
1350*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
1351*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly three substitutions, corresponding to the optimizable layers
1352*89c4ff92SAndroid Build Coastguard Worker     //  - No failed subgraphs
1353*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly two untouched subgraphs, corresponding to the non-optimizable layers
1354*89c4ff92SAndroid Build Coastguard Worker     // ===============================================================================
1355*89c4ff92SAndroid Build Coastguard Worker 
1356*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1357*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
1358*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1359*89c4ff92SAndroid Build Coastguard Worker 
1360*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
1361*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutions.size() == 3);
1362*89c4ff92SAndroid Build Coastguard Worker     // Sort into a consistent order
1363*89c4ff92SAndroid Build Coastguard Worker     std::sort(substitutions.begin(), substitutions.end(),
1364*89c4ff92SAndroid Build Coastguard Worker         [](auto s1, auto s2)
1365*89c4ff92SAndroid Build Coastguard Worker         { return strcmp(s1.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName(),
1366*89c4ff92SAndroid Build Coastguard Worker                         s2.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName()) < 0; });
1367*89c4ff92SAndroid Build Coastguard Worker 
1368*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 3 },
1369*89c4ff92SAndroid Build Coastguard Worker                                                                           { 1, 1, 3 },
1370*89c4ff92SAndroid Build Coastguard Worker                                                                           { 1, 1, 3 } };
1371*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 },
1372*89c4ff92SAndroid Build Coastguard Worker                                                                         { 1, 1, 1 },
1373*89c4ff92SAndroid Build Coastguard Worker                                                                         { 1, 1, 1 } };
1374*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IInputSlots> expectedSubstitutableInputSlots
1375*89c4ff92SAndroid Build Coastguard Worker     {
1376*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot, IInputSlot>(
1377*89c4ff92SAndroid Build Coastguard Worker             {ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlot(0))}),
1378*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot, IInputSlot>(
1379*89c4ff92SAndroid Build Coastguard Worker         {ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlot(0))}),
1380*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot, IInputSlot>(
1381*89c4ff92SAndroid Build Coastguard Worker         {ConvertReferenceTypeToPointerType(layersInGraph.at("conv5 layer")->GetInputSlot(0))})
1382*89c4ff92SAndroid Build Coastguard Worker     };
1383*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IOutputSlots> expectedSubstitutableOutputSlots
1384*89c4ff92SAndroid Build Coastguard Worker     {
1385*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1386*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots())),
1387*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1388*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetOutputSlots())),
1389*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1390*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("conv5 layer")->GetOutputSlots()))
1391*89c4ff92SAndroid Build Coastguard Worker     };
1392*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IConnectableLayers> expectedSubstitutableLayers
1393*89c4ff92SAndroid Build Coastguard Worker     {
1394*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("Weights Layer 1"), layersInGraph.at("Bias Layer 1"), layersInGraph.at("conv1 layer") },
1395*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("Weights Layer 3"), layersInGraph.at("Bias Layer 3"), layersInGraph.at("conv3 layer") },
1396*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("Weights Layer 5"), layersInGraph.at("Bias Layer 5"), layersInGraph.at("conv5 layer") }
1397*89c4ff92SAndroid Build Coastguard Worker     };
1398*89c4ff92SAndroid Build Coastguard Worker 
1399*89c4ff92SAndroid Build Coastguard Worker     for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++)
1400*89c4ff92SAndroid Build Coastguard Worker     {
1401*89c4ff92SAndroid Build Coastguard Worker         CheckSubstitution(substitutions.at(substitutionIndex),
1402*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableSubgraphSizes.at(substitutionIndex),
1403*89c4ff92SAndroid Build Coastguard Worker                           expectedReplacementSubgraphSizes.at(substitutionIndex),
1404*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableInputSlots.at(substitutionIndex),
1405*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableOutputSlots.at(substitutionIndex),
1406*89c4ff92SAndroid Build Coastguard Worker                           expectedSubstitutableLayers.at(substitutionIndex));
1407*89c4ff92SAndroid Build Coastguard Worker     }
1408*89c4ff92SAndroid Build Coastguard Worker 
1409*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1410*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
1411*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1412*89c4ff92SAndroid Build Coastguard Worker 
1413*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetFailedSubgraphs().empty());
1414*89c4ff92SAndroid Build Coastguard Worker 
1415*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1416*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
1417*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1418*89c4ff92SAndroid Build Coastguard Worker 
1419*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews::Subgraphs untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1420*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphs.size() == 2);
1421*89c4ff92SAndroid Build Coastguard Worker     // Sort into a consistent order
1422*89c4ff92SAndroid Build Coastguard Worker     std::sort(untouchedSubgraphs.begin(), untouchedSubgraphs.end(), [](auto s1, auto s2) {
1423*89c4ff92SAndroid Build Coastguard Worker         return strcmp(s1.GetIConnectableLayers().front()->GetName(),
1424*89c4ff92SAndroid Build Coastguard Worker                       s2.GetIConnectableLayers().front()->GetName()) < 0;
1425*89c4ff92SAndroid Build Coastguard Worker     });
1426*89c4ff92SAndroid Build Coastguard Worker 
1427*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 3 },
1428*89c4ff92SAndroid Build Coastguard Worker                                                                       { 1, 1, 3 } };
1429*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IInputSlots> expectedUntouchedInputSlots{
1430*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot,
1431*89c4ff92SAndroid Build Coastguard Worker                              IInputSlot>({ConvertReferenceTypeToPointerType(
1432*89c4ff92SAndroid Build Coastguard Worker             layersInGraph.at("conv2 layer unoptimizable")->GetInputSlot(0))}),
1433*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot,
1434*89c4ff92SAndroid Build Coastguard Worker                              IInputSlot>({ConvertReferenceTypeToPointerType(
1435*89c4ff92SAndroid Build Coastguard Worker             layersInGraph.at("conv4 layer unoptimizable")->GetInputSlot(0))})};
1436*89c4ff92SAndroid Build Coastguard Worker 
1437*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IOutputSlots> expectedUntouchedOutputSlots
1438*89c4ff92SAndroid Build Coastguard Worker         {
1439*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1440*89c4ff92SAndroid Build Coastguard Worker                 ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer unoptimizable")->GetOutputSlots())),
1441*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1442*89c4ff92SAndroid Build Coastguard Worker                 ConvertReferenceTypeToPointerType(layersInGraph.at("conv4 layer unoptimizable")->GetOutputSlots()))
1443*89c4ff92SAndroid Build Coastguard Worker         };
1444*89c4ff92SAndroid Build Coastguard Worker 
1445*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IConnectableLayers> expectedUntouchedLayers
1446*89c4ff92SAndroid Build Coastguard Worker         {
1447*89c4ff92SAndroid Build Coastguard Worker             { layersInGraph.at("Weights Layer 2 unoptimizable"),
1448*89c4ff92SAndroid Build Coastguard Worker               layersInGraph.at("Bias Layer 2 unoptimizable"),
1449*89c4ff92SAndroid Build Coastguard Worker               layersInGraph.at("conv2 layer unoptimizable") },
1450*89c4ff92SAndroid Build Coastguard Worker             { layersInGraph.at("Weights Layer 4 unoptimizable"),
1451*89c4ff92SAndroid Build Coastguard Worker               layersInGraph.at("Bias Layer 4 unoptimizable"),
1452*89c4ff92SAndroid Build Coastguard Worker               layersInGraph.at("conv4 layer unoptimizable") }
1453*89c4ff92SAndroid Build Coastguard Worker         };
1454*89c4ff92SAndroid Build Coastguard Worker 
1455*89c4ff92SAndroid Build Coastguard Worker     for (size_t untouchedIndex = 0; untouchedIndex < untouchedSubgraphs.size(); untouchedIndex++)
1456*89c4ff92SAndroid Build Coastguard Worker     {
1457*89c4ff92SAndroid Build Coastguard Worker         CheckUntouchedSubgraph(untouchedSubgraphs.at(untouchedIndex),
1458*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedSubgraphSizes.at(untouchedIndex),
1459*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedInputSlots.at(untouchedIndex),
1460*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedOutputSlots.at(untouchedIndex),
1461*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedLayers.at(untouchedIndex));
1462*89c4ff92SAndroid Build Coastguard Worker     }
1463*89c4ff92SAndroid Build Coastguard Worker }
1464*89c4ff92SAndroid Build Coastguard Worker 
1465*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains some unoptimizable layers ("unoptimizable" is added to the layer's name),
1466*89c4ff92SAndroid Build Coastguard Worker // this is meant to test input slots coming from different layers
PartiallyOptimizableSubgraphTestImpl2()1467*89c4ff92SAndroid Build Coastguard Worker void PartiallyOptimizableSubgraphTestImpl2()
1468*89c4ff92SAndroid Build Coastguard Worker {
1469*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
1470*89c4ff92SAndroid Build Coastguard Worker     LayerNameToLayerMap layersInGraph;
1471*89c4ff92SAndroid Build Coastguard Worker 
1472*89c4ff92SAndroid Build Coastguard Worker     // Create a partially optimizable subgraph
1473*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph2(graph, layersInGraph);
1474*89c4ff92SAndroid Build Coastguard Worker     CHECK((subgraphPtr != nullptr));
1475*89c4ff92SAndroid Build Coastguard Worker 
1476*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1477*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1478*89c4ff92SAndroid Build Coastguard Worker     const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1479*89c4ff92SAndroid Build Coastguard Worker 
1480*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphInputSlots.size()  == 2);
1481*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphOutputSlots.size() == 1);
1482*89c4ff92SAndroid Build Coastguard Worker     CHECK(subgraphLayers.size()      == 10);
1483*89c4ff92SAndroid Build Coastguard Worker 
1484*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv1 layer"));
1485*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv2 layer unoptimizable"));
1486*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "conv3 layer"));
1487*89c4ff92SAndroid Build Coastguard Worker     CHECK(Contains(layersInGraph, "add layer"));
1488*89c4ff92SAndroid Build Coastguard Worker 
1489*89c4ff92SAndroid Build Coastguard Worker     // Create a mock backend object
1490*89c4ff92SAndroid Build Coastguard Worker     MockBackendInitialiser initialiser; // Register the Mock Backend
1491*89c4ff92SAndroid Build Coastguard Worker     auto backendObjPtr = CreateBackendObject(MockBackendId());
1492*89c4ff92SAndroid Build Coastguard Worker     CHECK((backendObjPtr != nullptr));
1493*89c4ff92SAndroid Build Coastguard Worker 
1494*89c4ff92SAndroid Build Coastguard Worker     // Optimize the subgraph
1495*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews;
1496*89c4ff92SAndroid Build Coastguard Worker 
1497*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimization is carried out correctly
1498*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1499*89c4ff92SAndroid Build Coastguard Worker 
1500*89c4ff92SAndroid Build Coastguard Worker     // ==============================================================================
1501*89c4ff92SAndroid Build Coastguard Worker     // The expected results are:
1502*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly one substitution, corresponding to the optimizable layers
1503*89c4ff92SAndroid Build Coastguard Worker     //  - No failed subgraphs
1504*89c4ff92SAndroid Build Coastguard Worker     //  - Exactly two untouched subgraphs, corresponding to the non-optimizable layer
1505*89c4ff92SAndroid Build Coastguard Worker     // ==============================================================================
1506*89c4ff92SAndroid Build Coastguard Worker 
1507*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1508*89c4ff92SAndroid Build Coastguard Worker     // Check the substitutions
1509*89c4ff92SAndroid Build Coastguard Worker     // -----------------------
1510*89c4ff92SAndroid Build Coastguard Worker 
1511*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
1512*89c4ff92SAndroid Build Coastguard Worker     CHECK(substitutions.size() == 1);
1513*89c4ff92SAndroid Build Coastguard Worker 
1514*89c4ff92SAndroid Build Coastguard Worker     ExpectedSubgraphSize expectedSubstitutableSubgraphSizes{ 2, 1, 7 };
1515*89c4ff92SAndroid Build Coastguard Worker     ExpectedSubgraphSize expectedReplacementSubgraphSizes{ 2, 1, 1 };
1516*89c4ff92SAndroid Build Coastguard Worker 
1517*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::IInputSlots expectedSubstitutableInputSlots
1518*89c4ff92SAndroid Build Coastguard Worker     {
1519*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<InputSlot, IInputSlot>({
1520*89c4ff92SAndroid Build Coastguard Worker                     ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()[0])})[0],
1521*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<InputSlot, IInputSlot>({
1522*89c4ff92SAndroid Build Coastguard Worker                     ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots()[0])})[0]
1523*89c4ff92SAndroid Build Coastguard Worker     };
1524*89c4ff92SAndroid Build Coastguard Worker 
1525*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::IOutputSlots expectedSubstitutableOutputSlots
1526*89c4ff92SAndroid Build Coastguard Worker     {
1527*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1528*89c4ff92SAndroid Build Coastguard Worker                     ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots()))
1529*89c4ff92SAndroid Build Coastguard Worker     };
1530*89c4ff92SAndroid Build Coastguard Worker 
1531*89c4ff92SAndroid Build Coastguard Worker     SubgraphView::IConnectableLayers expectedSubstitutableLayers
1532*89c4ff92SAndroid Build Coastguard Worker     {
1533*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("Weights Layer 1"),
1534*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("Weights Layer 3"),
1535*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("Bias Layer 1"),
1536*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("Bias Layer 3"),
1537*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("conv1 layer"),
1538*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("conv3 layer"),
1539*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("add layer")
1540*89c4ff92SAndroid Build Coastguard Worker     };
1541*89c4ff92SAndroid Build Coastguard Worker 
1542*89c4ff92SAndroid Build Coastguard Worker     CheckSubstitution(substitutions[0],
1543*89c4ff92SAndroid Build Coastguard Worker                       expectedSubstitutableSubgraphSizes,
1544*89c4ff92SAndroid Build Coastguard Worker                       expectedReplacementSubgraphSizes,
1545*89c4ff92SAndroid Build Coastguard Worker                       expectedSubstitutableInputSlots,
1546*89c4ff92SAndroid Build Coastguard Worker                       expectedSubstitutableOutputSlots,
1547*89c4ff92SAndroid Build Coastguard Worker                       expectedSubstitutableLayers);
1548*89c4ff92SAndroid Build Coastguard Worker 
1549*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1550*89c4ff92SAndroid Build Coastguard Worker     // Check the failed subgraphs
1551*89c4ff92SAndroid Build Coastguard Worker     // --------------------------
1552*89c4ff92SAndroid Build Coastguard Worker 
1553*89c4ff92SAndroid Build Coastguard Worker     CHECK(optimizationViews.GetFailedSubgraphs().empty());
1554*89c4ff92SAndroid Build Coastguard Worker 
1555*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1556*89c4ff92SAndroid Build Coastguard Worker     // Check the untouched subgraphs
1557*89c4ff92SAndroid Build Coastguard Worker     // -----------------------------
1558*89c4ff92SAndroid Build Coastguard Worker 
1559*89c4ff92SAndroid Build Coastguard Worker     const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1560*89c4ff92SAndroid Build Coastguard Worker     CHECK(untouchedSubgraphs.size() == 1);
1561*89c4ff92SAndroid Build Coastguard Worker 
1562*89c4ff92SAndroid Build Coastguard Worker     std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 3 } };
1563*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IInputSlots> expectedUntouchedInputSlots
1564*89c4ff92SAndroid Build Coastguard Worker     {
1565*89c4ff92SAndroid Build Coastguard Worker         ConvertSlotsToISlots<InputSlot,
1566*89c4ff92SAndroid Build Coastguard Worker                              IInputSlot>({ConvertReferenceTypeToPointerType(
1567*89c4ff92SAndroid Build Coastguard Worker             layersInGraph.at("conv2 layer unoptimizable")->GetInputSlot(0))})};
1568*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IOutputSlots> expectedUntouchedOutputSlots
1569*89c4ff92SAndroid Build Coastguard Worker     {
1570*89c4ff92SAndroid Build Coastguard Worker             ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1571*89c4ff92SAndroid Build Coastguard Worker         ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer unoptimizable")->GetOutputSlots()))
1572*89c4ff92SAndroid Build Coastguard Worker     };
1573*89c4ff92SAndroid Build Coastguard Worker     std::vector<SubgraphView::IConnectableLayers> expectedUntouchedLayers
1574*89c4ff92SAndroid Build Coastguard Worker     {
1575*89c4ff92SAndroid Build Coastguard Worker         { layersInGraph.at("conv2 layer unoptimizable"), layersInGraph.at("Weights Layer 2 unoptimizable"),
1576*89c4ff92SAndroid Build Coastguard Worker         layersInGraph.at("Bias Layer 2 unoptimizable") }
1577*89c4ff92SAndroid Build Coastguard Worker     };
1578*89c4ff92SAndroid Build Coastguard Worker 
1579*89c4ff92SAndroid Build Coastguard Worker     for (size_t untouchedIndex = 0; untouchedIndex < untouchedSubgraphs.size(); untouchedIndex++)
1580*89c4ff92SAndroid Build Coastguard Worker     {
1581*89c4ff92SAndroid Build Coastguard Worker         CheckUntouchedSubgraph(untouchedSubgraphs.at(untouchedIndex),
1582*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedSubgraphSizes.at(untouchedIndex),
1583*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedInputSlots.at(untouchedIndex),
1584*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedOutputSlots.at(untouchedIndex),
1585*89c4ff92SAndroid Build Coastguard Worker                                expectedUntouchedLayers.at(untouchedIndex));
1586*89c4ff92SAndroid Build Coastguard Worker     }
1587*89c4ff92SAndroid Build Coastguard Worker }
1588*89c4ff92SAndroid Build Coastguard Worker 
1589*89c4ff92SAndroid Build Coastguard Worker } // Anonymous namespace
1590*89c4ff92SAndroid Build Coastguard Worker 
1591*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OptimizeSubGraph")
1592*89c4ff92SAndroid Build Coastguard Worker {
1593*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnsupportedSubgraph1")     { FullyUnsupporteSubgraphTestImpl1();      }
1594*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnsupportedSubgraph2")     { FullyUnsupporteSubgraphTestImpl2();      }
1595*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyOptimizableSubgraph1")     { FullyOptimizableSubgraphTestImpl1();     }
1596*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyOptimizableSubgraph2")     { FullyOptimizableSubgraphTestImpl2();     }
1597*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallySupportedSubgraph")    { PartiallySupportedSubgraphTestImpl();    }
1598*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnoptimizableSubgraph")    { FullyUnoptimizableSubgraphTestImpl1();   }
1599*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallyOptimizableSubgraph1") { PartiallyOptimizableSubgraphTestImpl1(); }
1600*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallyOptimizableSubgraph2") { PartiallyOptimizableSubgraphTestImpl2(); }
1601*89c4ff92SAndroid Build Coastguard Worker 
1602*89c4ff92SAndroid Build Coastguard Worker }
1603