1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017, 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
#include <CommonTestUtils.hpp>
#include "MockBackendId.hpp"

#include <Graph.hpp>
#include <Network.hpp>

#include <armnn/BackendRegistry.hpp>
#include <armnnTestUtils/MockBackend.hpp>

#include <doctest/doctest.h>

#include <cstdint>
#include <unordered_map>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker namespace
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker
// Sizes a test expects a subgraph to have once the test has run:
// how many input slots, output slots and layers it contains. All counts default to zero.
struct ExpectedSubgraphSize
{
    size_t m_NumInputSlots{ 0 };
    size_t m_NumOutputSlots{ 0 };
    size_t m_NumLayers{ 0 };
};
30*89c4ff92SAndroid Build Coastguard Worker
// Keep the layers organized by layer name.
// Lets tests look up a layer they created by its name (the key is the layer's name string,
// the value is a non-owning pointer to the layer inside the graph).
using LayerNameToLayerMap = std::unordered_map<std::string, Layer*>;
33*89c4ff92SAndroid Build Coastguard Worker
// Turns a slot held by const reference (the form graphs store slots in) into a mutable
// pointer (the form subgraphs store slots in). The constness is cast away, so writing
// through the result is only valid if the referenced object was not itself declared const.
template <typename SlotType>
SlotType* ConvertReferenceTypeToPointerType(const SlotType& input)
{
    SlotType& mutableSlot = const_cast<SlotType&>(input);
    return &mutableSlot;
}
41*89c4ff92SAndroid Build Coastguard Worker
// Used to convert input and output slots from reference type (as stored in graphs) to
// pointer type (as stored in subgraphs), array version.
// Each returned pointer refers to an element of `input`, so `input` must outlive the result.
template <typename SlotType>
std::vector<SlotType*> ConvertReferenceTypeToPointerType(const std::vector<SlotType>& input)
{
    std::vector<SlotType*> output;
    // The final size is known up front: allocate once rather than growing element by element.
    output.reserve(input.size());
    for (const SlotType& inputItem : input)
    {
        // Delegate to the single-element overload so the conversion lives in one place.
        output.push_back(ConvertReferenceTypeToPointerType(inputItem));
    }
    return output;
}
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker // Convert from vector of Slots* (Input/Output) to vector of ISlots* (IInput/IOutput)
60*89c4ff92SAndroid Build Coastguard Worker template <typename SlotType, typename ResultSlotType>
ConvertSlotsToISlots(const std::vector<SlotType * > input)61*89c4ff92SAndroid Build Coastguard Worker std::vector<ResultSlotType*> ConvertSlotsToISlots(const std::vector<SlotType*> input)
62*89c4ff92SAndroid Build Coastguard Worker {
63*89c4ff92SAndroid Build Coastguard Worker std::vector<ResultSlotType*> output;
64*89c4ff92SAndroid Build Coastguard Worker for (auto slot : input)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker output.push_back(PolymorphicDowncast<ResultSlotType*>(slot));
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker return output;
69*89c4ff92SAndroid Build Coastguard Worker }
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an input layer to a graph
AddInputLayer(Graph & graph,const std::string & layerName,const TensorInfo & inputInfo,LayerBindingId inputId=0)72*89c4ff92SAndroid Build Coastguard Worker Layer* AddInputLayer(Graph& graph,
73*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName,
74*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo,
75*89c4ff92SAndroid Build Coastguard Worker LayerBindingId inputId = 0)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = graph.AddLayer<InputLayer>(inputId, layerName.c_str());
78*89c4ff92SAndroid Build Coastguard Worker CHECK(inputLayer);
79*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).SetTensorInfo(inputInfo);
80*89c4ff92SAndroid Build Coastguard Worker return inputLayer;
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
83*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an output layer to a graph
AddOutputLayer(Graph & graph,const std::string & layerName)84*89c4ff92SAndroid Build Coastguard Worker Layer* AddOutputLayer(Graph& graph,
85*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = graph.AddLayer<OutputLayer>(0, layerName.c_str());
88*89c4ff92SAndroid Build Coastguard Worker CHECK(outputLayer);
89*89c4ff92SAndroid Build Coastguard Worker return outputLayer;
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker
92*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a convolution layer to a graph
AddConvolutionLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const Convolution2dDescriptor & convolutionDescriptor,const std::string & layerName,const TensorInfo & outputInfo)93*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* AddConvolutionLayer(Graph& graph,
94*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap& layersInGraph,
95*89c4ff92SAndroid Build Coastguard Worker const Convolution2dDescriptor& convolutionDescriptor,
96*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName,
97*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const convLayer = graph.AddLayer<Convolution2dLayer>(convolutionDescriptor, layerName.c_str());
100*89c4ff92SAndroid Build Coastguard Worker CHECK(convLayer);
101*89c4ff92SAndroid Build Coastguard Worker convLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
102*89c4ff92SAndroid Build Coastguard Worker layersInGraph.insert(std::make_pair(convLayer->GetName(), convLayer));
103*89c4ff92SAndroid Build Coastguard Worker return convLayer;
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker
106*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a constant layer to a graph
AddConstantLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const std::string & layerName,const ConstTensor & constTensor,const TensorInfo & outputInfo)107*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* AddConstantLayer(Graph& graph,
108*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap& layersInGraph,
109*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName,
110*89c4ff92SAndroid Build Coastguard Worker const ConstTensor& constTensor,
111*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo)
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const constantLayer = graph.AddLayer<ConstantLayer>(layerName.c_str());
114*89c4ff92SAndroid Build Coastguard Worker CHECK(constantLayer);
115*89c4ff92SAndroid Build Coastguard Worker constantLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor);
116*89c4ff92SAndroid Build Coastguard Worker constantLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
117*89c4ff92SAndroid Build Coastguard Worker layersInGraph.insert(std::make_pair(constantLayer->GetName(), constantLayer));
118*89c4ff92SAndroid Build Coastguard Worker return constantLayer;
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add a pooling layer to a graph
AddPoolingLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const Pooling2dDescriptor & poolingDescriptor,const std::string & layerName,const TensorInfo & outputInfo)122*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* AddPoolingLayer(Graph& graph,
123*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap& layersInGraph,
124*89c4ff92SAndroid Build Coastguard Worker const Pooling2dDescriptor& poolingDescriptor,
125*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName,
126*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo)
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const poolingLayer = graph.AddLayer<Pooling2dLayer>(poolingDescriptor, layerName.c_str());
129*89c4ff92SAndroid Build Coastguard Worker CHECK(poolingLayer);
130*89c4ff92SAndroid Build Coastguard Worker poolingLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
131*89c4ff92SAndroid Build Coastguard Worker layersInGraph.insert(std::make_pair(poolingLayer->GetName(), poolingLayer));
132*89c4ff92SAndroid Build Coastguard Worker return poolingLayer;
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker // Convenience function to add an addition layer to a graph
AddAdditionaLayer(Graph & graph,LayerNameToLayerMap & layersInGraph,const std::string & layerName,const TensorInfo & outputInfo)136*89c4ff92SAndroid Build Coastguard Worker AdditionLayer* AddAdditionaLayer(Graph& graph,
137*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap& layersInGraph,
138*89c4ff92SAndroid Build Coastguard Worker const std::string& layerName,
139*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker AdditionLayer* const additionLayer = graph.AddLayer<AdditionLayer>(layerName.c_str());
142*89c4ff92SAndroid Build Coastguard Worker CHECK(additionLayer);
143*89c4ff92SAndroid Build Coastguard Worker additionLayer->GetOutputSlot(0).SetTensorInfo(outputInfo);
144*89c4ff92SAndroid Build Coastguard Worker layersInGraph.insert(std::make_pair(additionLayer->GetName(), additionLayer));
145*89c4ff92SAndroid Build Coastguard Worker return additionLayer;
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker
148*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given substitution matches the specified expected values
CheckSubstitution(const OptimizationViews::SubstitutionPair & substitution,const ExpectedSubgraphSize & expectedSubstitutableSubgraphSize,const ExpectedSubgraphSize & expectedReplacementSubgraphSize,const SubgraphView::IInputSlots & expectedSubstitutableInputSlots,const SubgraphView::IOutputSlots & expectedSubstitutableOutputSlots,const SubgraphView::IConnectableLayers & expectedSubstitutableLayers)149*89c4ff92SAndroid Build Coastguard Worker void CheckSubstitution(const OptimizationViews::SubstitutionPair& substitution,
150*89c4ff92SAndroid Build Coastguard Worker const ExpectedSubgraphSize& expectedSubstitutableSubgraphSize,
151*89c4ff92SAndroid Build Coastguard Worker const ExpectedSubgraphSize& expectedReplacementSubgraphSize,
152*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& expectedSubstitutableInputSlots,
153*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& expectedSubstitutableOutputSlots,
154*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& expectedSubstitutableLayers)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker const SubgraphView& substitutableSubgraph = substitution.m_SubstitutableSubgraph;
157*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& substitutableSubgraphInputSlots = substitutableSubgraph.GetIInputSlots();
158*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& substitutableSubgraphOutputSlots = substitutableSubgraph.GetIOutputSlots();
159*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& substitutableSubgraphLayers =
160*89c4ff92SAndroid Build Coastguard Worker substitutableSubgraph.GetIConnectableLayers();
161*89c4ff92SAndroid Build Coastguard Worker
162*89c4ff92SAndroid Build Coastguard Worker const SubgraphView& replacementSubgraph = substitution.m_ReplacementSubgraph;
163*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& replacementSubgraphInputSlots = replacementSubgraph.GetIInputSlots();
164*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& replacementSubgraphOutputSlots = replacementSubgraph.GetIOutputSlots();
165*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& replacementSubgraphLayers = replacementSubgraph.GetIConnectableLayers();
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutableSubgraphInputSlots.size() == expectedSubstitutableSubgraphSize.m_NumInputSlots);
168*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutableSubgraphOutputSlots.size() == expectedSubstitutableSubgraphSize.m_NumOutputSlots);
169*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutableSubgraphLayers.size() == expectedSubstitutableSubgraphSize.m_NumLayers);
170*89c4ff92SAndroid Build Coastguard Worker
171*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(substitutableSubgraphInputSlots, expectedSubstitutableInputSlots));
172*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(substitutableSubgraphOutputSlots, expectedSubstitutableOutputSlots));
173*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(substitutableSubgraphLayers, expectedSubstitutableLayers));
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker CHECK(replacementSubgraphInputSlots.size() == expectedReplacementSubgraphSize.m_NumInputSlots);
176*89c4ff92SAndroid Build Coastguard Worker CHECK(replacementSubgraphOutputSlots.size() == expectedReplacementSubgraphSize.m_NumOutputSlots);
177*89c4ff92SAndroid Build Coastguard Worker CHECK(replacementSubgraphLayers.size() == expectedReplacementSubgraphSize.m_NumLayers);
178*89c4ff92SAndroid Build Coastguard Worker
179*89c4ff92SAndroid Build Coastguard Worker CHECK(!AreEqual(replacementSubgraphInputSlots, expectedSubstitutableInputSlots));
180*89c4ff92SAndroid Build Coastguard Worker CHECK(!AreEqual(replacementSubgraphOutputSlots, expectedSubstitutableOutputSlots));
181*89c4ff92SAndroid Build Coastguard Worker CHECK(!AreEqual(replacementSubgraphLayers, expectedSubstitutableLayers));
182*89c4ff92SAndroid Build Coastguard Worker
183*89c4ff92SAndroid Build Coastguard Worker CHECK(std::all_of(replacementSubgraphLayers.begin(),
184*89c4ff92SAndroid Build Coastguard Worker replacementSubgraphLayers.end(),
185*89c4ff92SAndroid Build Coastguard Worker [](const IConnectableLayer* layer)
186*89c4ff92SAndroid Build Coastguard Worker {
187*89c4ff92SAndroid Build Coastguard Worker return layer->GetType() == LayerType::PreCompiled;
188*89c4ff92SAndroid Build Coastguard Worker }));
189*89c4ff92SAndroid Build Coastguard Worker }
190*89c4ff92SAndroid Build Coastguard Worker
191*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given failed subgraph matches the specified expected values
CheckFailedSubgraph(const SubgraphView & failedSubgraph,const ExpectedSubgraphSize & expectedFailedSubgraphSize,const SubgraphView::IInputSlots & expectedFailedInputSlots,const SubgraphView::IOutputSlots & expectedFailedOutputSlots,const SubgraphView::IConnectableLayers & expectedFailedLayers)192*89c4ff92SAndroid Build Coastguard Worker void CheckFailedSubgraph(const SubgraphView& failedSubgraph,
193*89c4ff92SAndroid Build Coastguard Worker const ExpectedSubgraphSize& expectedFailedSubgraphSize,
194*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& expectedFailedInputSlots,
195*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& expectedFailedOutputSlots,
196*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& expectedFailedLayers)
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& failedSubgraphInputSlots = failedSubgraph.GetIInputSlots();
199*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& failedSubgraphOutputSlots = failedSubgraph.GetIOutputSlots();
200*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& failedSubgraphLayers = failedSubgraph.GetIConnectableLayers();
201*89c4ff92SAndroid Build Coastguard Worker
202*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphInputSlots.size() == expectedFailedSubgraphSize.m_NumInputSlots);
203*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphOutputSlots.size() == expectedFailedSubgraphSize.m_NumOutputSlots);
204*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphLayers.size() == expectedFailedSubgraphSize.m_NumLayers);
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(failedSubgraphInputSlots, expectedFailedInputSlots));
207*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(failedSubgraphOutputSlots, expectedFailedOutputSlots));
208*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(failedSubgraphLayers, expectedFailedLayers));
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker
211*89c4ff92SAndroid Build Coastguard Worker // Convenience function to check that the given untouched subgraph matches the specified expected values
CheckUntouchedSubgraph(const SubgraphView & untouchedSubgraph,const ExpectedSubgraphSize & expectedUntouchedSubgraphSize,const SubgraphView::IInputSlots & expectedUntouchedInputSlots,const SubgraphView::IOutputSlots & expectedUntouchedOutputSlots,const SubgraphView::IConnectableLayers & expectedUntouchedLayers)212*89c4ff92SAndroid Build Coastguard Worker void CheckUntouchedSubgraph(const SubgraphView& untouchedSubgraph,
213*89c4ff92SAndroid Build Coastguard Worker const ExpectedSubgraphSize& expectedUntouchedSubgraphSize,
214*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& expectedUntouchedInputSlots,
215*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& expectedUntouchedOutputSlots,
216*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& expectedUntouchedLayers)
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& untouchedSubgraphInputSlots = untouchedSubgraph.GetIInputSlots();
219*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& untouchedSubgraphOutputSlots = untouchedSubgraph.GetIOutputSlots();
220*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& untouchedSubgraphLayers = untouchedSubgraph.GetIConnectableLayers();
221*89c4ff92SAndroid Build Coastguard Worker
222*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphInputSlots.size() == expectedUntouchedSubgraphSize.m_NumInputSlots);
223*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphOutputSlots.size() == expectedUntouchedSubgraphSize.m_NumOutputSlots);
224*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphLayers.size() == expectedUntouchedSubgraphSize.m_NumLayers);
225*89c4ff92SAndroid Build Coastguard Worker
226*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(untouchedSubgraphInputSlots, expectedUntouchedInputSlots));
227*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(untouchedSubgraphOutputSlots, expectedUntouchedOutputSlots));
228*89c4ff92SAndroid Build Coastguard Worker CHECK(AreEqual(untouchedSubgraphLayers, expectedUntouchedLayers));
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker
231*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph containing only a single unsupported layer (only convolutions are unsupported by the mock backend)
BuildFullyUnsupportedSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)232*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnsupportedSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
233*89c4ff92SAndroid Build Coastguard Worker {
234*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
235*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
236*89c4ff92SAndroid Build Coastguard Worker
237*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor poolingDescriptor;
238*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolType = armnn::PoolingAlgorithm::Average;
239*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolWidth = 2;
240*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolHeight = 2;
241*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideX = 2;
242*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideY = 2;
243*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadLeft = 1;
244*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadRight = 1;
245*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadTop = 1;
246*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadBottom = 1;
247*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
248*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_DataLayout = DataLayout::NHWC;
249*89c4ff92SAndroid Build Coastguard Worker
250*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
251*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
252*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const poolingLayer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
253*89c4ff92SAndroid Build Coastguard Worker "pooling layer", outputInfo);
254*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
255*89c4ff92SAndroid Build Coastguard Worker
256*89c4ff92SAndroid Build Coastguard Worker // Connect the network
257*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(poolingLayer->GetInputSlot(0));
258*89c4ff92SAndroid Build Coastguard Worker poolingLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
259*89c4ff92SAndroid Build Coastguard Worker
260*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
261*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(poolingLayer),
262*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({poolingLayer}),
263*89c4ff92SAndroid Build Coastguard Worker {poolingLayer});
264*89c4ff92SAndroid Build Coastguard Worker }
265*89c4ff92SAndroid Build Coastguard Worker
266*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph containing only unsupported layers (only convolutions are unsupported by the mock backend)
BuildFullyUnsupportedSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)267*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnsupportedSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
270*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
271*89c4ff92SAndroid Build Coastguard Worker
272*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor poolingDescriptor;
273*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolType = armnn::PoolingAlgorithm::Average;
274*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolWidth = 2;
275*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolHeight = 2;
276*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideX = 2;
277*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideY = 2;
278*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadLeft = 1;
279*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadRight = 1;
280*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadTop = 1;
281*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadBottom = 1;
282*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
283*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_DataLayout = DataLayout::NHWC;
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
286*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
287*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling1Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
288*89c4ff92SAndroid Build Coastguard Worker "pooling1 layer", outputInfo);
289*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling2Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
290*89c4ff92SAndroid Build Coastguard Worker "pooling2 layer", outputInfo);
291*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling3Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
292*89c4ff92SAndroid Build Coastguard Worker "pooling3 layer", outputInfo);
293*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
294*89c4ff92SAndroid Build Coastguard Worker
295*89c4ff92SAndroid Build Coastguard Worker // Connect the network
296*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(pooling1Layer->GetInputSlot(0));
297*89c4ff92SAndroid Build Coastguard Worker pooling1Layer->GetOutputSlot(0).Connect(pooling2Layer->GetInputSlot(0));
298*89c4ff92SAndroid Build Coastguard Worker pooling2Layer->GetOutputSlot(0).Connect(pooling3Layer->GetInputSlot(0));
299*89c4ff92SAndroid Build Coastguard Worker pooling3Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
300*89c4ff92SAndroid Build Coastguard Worker
301*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
302*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(pooling1Layer),
303*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({pooling3Layer}),
304*89c4ff92SAndroid Build Coastguard Worker {pooling1Layer,
305*89c4ff92SAndroid Build Coastguard Worker pooling2Layer,
306*89c4ff92SAndroid Build Coastguard Worker pooling3Layer});
307*89c4ff92SAndroid Build Coastguard Worker }
308*89c4ff92SAndroid Build Coastguard Worker
309*89c4ff92SAndroid Build Coastguard Worker // Creates a simple subgraph with only one convolution layer, supported by the mock backend
BuildFullyOptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)310*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyOptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
313*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
314*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
315*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
318*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
319*89c4ff92SAndroid Build Coastguard Worker
320*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
321*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
322*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
323*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
324*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
327*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
328*89c4ff92SAndroid Build Coastguard Worker
329*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
330*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
331*89c4ff92SAndroid Build Coastguard Worker
332*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
333*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
334*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const convLayer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
335*89c4ff92SAndroid Build Coastguard Worker "conv layer", outputInfo);
336*89c4ff92SAndroid Build Coastguard Worker
337*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer =
338*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer", constWeightsTensor, weightInfo);
339*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer = AddConstantLayer(graph, layersInGraph, "Bias Layer", constBiasTensor, biasInfo);
340*89c4ff92SAndroid Build Coastguard Worker
341*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
342*89c4ff92SAndroid Build Coastguard Worker
343*89c4ff92SAndroid Build Coastguard Worker // Connect the network
344*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
345*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
346*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
347*89c4ff92SAndroid Build Coastguard Worker convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
348*89c4ff92SAndroid Build Coastguard Worker
349*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
350*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
351*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(convLayer, ignoreSlots),
352*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer}),
353*89c4ff92SAndroid Build Coastguard Worker {convLayer, weightsLayer, biasLayer});
354*89c4ff92SAndroid Build Coastguard Worker }
355*89c4ff92SAndroid Build Coastguard Worker
356*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with five convolutions layers, all supported by the mock backend
BuildFullyOptimizableSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)357*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyOptimizableSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
358*89c4ff92SAndroid Build Coastguard Worker {
359*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
360*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
361*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
362*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo ({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
363*89c4ff92SAndroid Build Coastguard Worker
364*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
365*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
366*89c4ff92SAndroid Build Coastguard Worker
367*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
368*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
369*89c4ff92SAndroid Build Coastguard Worker
370*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
371*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
372*89c4ff92SAndroid Build Coastguard Worker
373*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
374*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
375*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
376*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
377*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
378*89c4ff92SAndroid Build Coastguard Worker
379*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
380*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
381*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
382*89c4ff92SAndroid Build Coastguard Worker "conv1 layer", outputInfo);
383*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer1 =
384*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
385*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer1 =
386*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
387*89c4ff92SAndroid Build Coastguard Worker
388*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
389*89c4ff92SAndroid Build Coastguard Worker "conv2 layer", outputInfo);
390*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer2 =
391*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 2", constWeightsTensor, weightInfo);
392*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer2 =
393*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 2", constBiasTensor, biasInfo);
394*89c4ff92SAndroid Build Coastguard Worker
395*89c4ff92SAndroid Build Coastguard Worker
396*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
397*89c4ff92SAndroid Build Coastguard Worker "conv3 layer", outputInfo);
398*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer3 =
399*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
400*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer3 =
401*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
402*89c4ff92SAndroid Build Coastguard Worker
403*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv4Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
404*89c4ff92SAndroid Build Coastguard Worker "conv4 layer", outputInfo);
405*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer4 =
406*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 4", constWeightsTensor, weightInfo);
407*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer4 =
408*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 4", constBiasTensor, biasInfo);
409*89c4ff92SAndroid Build Coastguard Worker
410*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv5Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
411*89c4ff92SAndroid Build Coastguard Worker "conv5 layer", outputInfo);
412*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer5 =
413*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 5", constWeightsTensor, weightInfo);
414*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer5 =
415*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 5", constBiasTensor, biasInfo);
416*89c4ff92SAndroid Build Coastguard Worker
417*89c4ff92SAndroid Build Coastguard Worker
418*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
419*89c4ff92SAndroid Build Coastguard Worker
420*89c4ff92SAndroid Build Coastguard Worker // Connect the network
421*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
422*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
423*89c4ff92SAndroid Build Coastguard Worker biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
424*89c4ff92SAndroid Build Coastguard Worker
425*89c4ff92SAndroid Build Coastguard Worker conv1Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
426*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
427*89c4ff92SAndroid Build Coastguard Worker biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
428*89c4ff92SAndroid Build Coastguard Worker
429*89c4ff92SAndroid Build Coastguard Worker conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
430*89c4ff92SAndroid Build Coastguard Worker weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
431*89c4ff92SAndroid Build Coastguard Worker biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
432*89c4ff92SAndroid Build Coastguard Worker
433*89c4ff92SAndroid Build Coastguard Worker conv3Layer->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(0));
434*89c4ff92SAndroid Build Coastguard Worker weightsLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(1));
435*89c4ff92SAndroid Build Coastguard Worker biasLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(2));
436*89c4ff92SAndroid Build Coastguard Worker
437*89c4ff92SAndroid Build Coastguard Worker conv4Layer->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(0));
438*89c4ff92SAndroid Build Coastguard Worker weightsLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(1));
439*89c4ff92SAndroid Build Coastguard Worker biasLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(2));
440*89c4ff92SAndroid Build Coastguard Worker
441*89c4ff92SAndroid Build Coastguard Worker conv5Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
442*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
443*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
444*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
445*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({ conv5Layer }),
446*89c4ff92SAndroid Build Coastguard Worker { weightsLayer1,
447*89c4ff92SAndroid Build Coastguard Worker biasLayer1,
448*89c4ff92SAndroid Build Coastguard Worker conv1Layer,
449*89c4ff92SAndroid Build Coastguard Worker weightsLayer2,
450*89c4ff92SAndroid Build Coastguard Worker biasLayer2,
451*89c4ff92SAndroid Build Coastguard Worker conv2Layer,
452*89c4ff92SAndroid Build Coastguard Worker weightsLayer3,
453*89c4ff92SAndroid Build Coastguard Worker biasLayer3,
454*89c4ff92SAndroid Build Coastguard Worker conv3Layer,
455*89c4ff92SAndroid Build Coastguard Worker weightsLayer4,
456*89c4ff92SAndroid Build Coastguard Worker biasLayer4,
457*89c4ff92SAndroid Build Coastguard Worker conv4Layer,
458*89c4ff92SAndroid Build Coastguard Worker weightsLayer5,
459*89c4ff92SAndroid Build Coastguard Worker biasLayer5,
460*89c4ff92SAndroid Build Coastguard Worker conv5Layer });
461*89c4ff92SAndroid Build Coastguard Worker }
462*89c4ff92SAndroid Build Coastguard Worker
463*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with both supported and unsupported layers
464*89c4ff92SAndroid Build Coastguard Worker // (only convolutions are unsupported by the mock backend)
BuildPartiallySupportedSubgraph(Graph & graph,LayerNameToLayerMap & layersInGraph)465*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallySupportedSubgraph(Graph& graph, LayerNameToLayerMap& layersInGraph)
466*89c4ff92SAndroid Build Coastguard Worker {
467*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
468*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
469*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
470*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo ({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
471*89c4ff92SAndroid Build Coastguard Worker
472*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
473*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
474*89c4ff92SAndroid Build Coastguard Worker
475*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
476*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
477*89c4ff92SAndroid Build Coastguard Worker
478*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
479*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
480*89c4ff92SAndroid Build Coastguard Worker
481*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
482*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
483*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
484*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
485*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
486*89c4ff92SAndroid Build Coastguard Worker
487*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor poolingDescriptor;
488*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolType = armnn::PoolingAlgorithm::Average;
489*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolWidth = 2;
490*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PoolHeight = 2;
491*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideX = 2;
492*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_StrideY = 2;
493*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadLeft = 1;
494*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadRight = 1;
495*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadTop = 1;
496*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PadBottom = 1;
497*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_PaddingMethod = armnn::PaddingMethod::Exclude;
498*89c4ff92SAndroid Build Coastguard Worker poolingDescriptor.m_DataLayout = DataLayout::NHWC;
499*89c4ff92SAndroid Build Coastguard Worker
500*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
501*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
502*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer1 =
503*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
504*89c4ff92SAndroid Build Coastguard Worker
505*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer1 =
506*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
507*89c4ff92SAndroid Build Coastguard Worker
508*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
509*89c4ff92SAndroid Build Coastguard Worker "conv1 layer", outputInfo);
510*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling1Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
511*89c4ff92SAndroid Build Coastguard Worker "pooling1 layer", outputInfo);
512*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling2Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
513*89c4ff92SAndroid Build Coastguard Worker "pooling2 layer", outputInfo);
514*89c4ff92SAndroid Build Coastguard Worker
515*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer2 =
516*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 2", constWeightsTensor, weightInfo);
517*89c4ff92SAndroid Build Coastguard Worker
518*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer2 =
519*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 2", constBiasTensor, biasInfo);
520*89c4ff92SAndroid Build Coastguard Worker
521*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
522*89c4ff92SAndroid Build Coastguard Worker "conv2 layer", outputInfo);
523*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* const pooling3Layer = AddPoolingLayer(graph, layersInGraph, poolingDescriptor,
524*89c4ff92SAndroid Build Coastguard Worker "pooling3 layer", outputInfo);
525*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
526*89c4ff92SAndroid Build Coastguard Worker
527*89c4ff92SAndroid Build Coastguard Worker // Connect the network
528*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
529*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
530*89c4ff92SAndroid Build Coastguard Worker biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
531*89c4ff92SAndroid Build Coastguard Worker conv1Layer->GetOutputSlot(0).Connect(pooling1Layer->GetInputSlot(0));
532*89c4ff92SAndroid Build Coastguard Worker pooling1Layer->GetOutputSlot(0).Connect(pooling2Layer->GetInputSlot(0));
533*89c4ff92SAndroid Build Coastguard Worker pooling2Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
534*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
535*89c4ff92SAndroid Build Coastguard Worker biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
536*89c4ff92SAndroid Build Coastguard Worker conv2Layer->GetOutputSlot(0).Connect(pooling3Layer->GetInputSlot(0));
537*89c4ff92SAndroid Build Coastguard Worker pooling3Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
538*89c4ff92SAndroid Build Coastguard Worker
539*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
540*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
541*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
542*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({pooling3Layer}),
543*89c4ff92SAndroid Build Coastguard Worker {weightsLayer1,
544*89c4ff92SAndroid Build Coastguard Worker biasLayer1,
545*89c4ff92SAndroid Build Coastguard Worker conv1Layer,
546*89c4ff92SAndroid Build Coastguard Worker pooling1Layer,
547*89c4ff92SAndroid Build Coastguard Worker pooling2Layer,
548*89c4ff92SAndroid Build Coastguard Worker weightsLayer2,
549*89c4ff92SAndroid Build Coastguard Worker biasLayer2,
550*89c4ff92SAndroid Build Coastguard Worker conv2Layer,
551*89c4ff92SAndroid Build Coastguard Worker pooling3Layer});
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker
554*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with only unoptimizable layers ("unoptimizable" is added to the layer's name)
BuildFullyUnoptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)555*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildFullyUnoptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
556*89c4ff92SAndroid Build Coastguard Worker {
557*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
558*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
559*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
560*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo ({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
561*89c4ff92SAndroid Build Coastguard Worker
562*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
563*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
564*89c4ff92SAndroid Build Coastguard Worker
565*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
566*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
567*89c4ff92SAndroid Build Coastguard Worker
568*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
569*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
570*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
571*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
572*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
573*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
574*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
575*89c4ff92SAndroid Build Coastguard Worker
576*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
577*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
578*89c4ff92SAndroid Build Coastguard Worker
579*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer =
580*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer unoptimizable", constWeightsTensor, weightInfo);
581*89c4ff92SAndroid Build Coastguard Worker
582*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer =
583*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer unoptimizable", constBiasTensor, biasInfo);
584*89c4ff92SAndroid Build Coastguard Worker
585*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const convLayer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
586*89c4ff92SAndroid Build Coastguard Worker "conv layer unoptimizable", outputInfo);
587*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
588*89c4ff92SAndroid Build Coastguard Worker
589*89c4ff92SAndroid Build Coastguard Worker // Connect the network
590*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(0));
591*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(1));
592*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(convLayer->GetInputSlot(2));
593*89c4ff92SAndroid Build Coastguard Worker convLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
594*89c4ff92SAndroid Build Coastguard Worker
595*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
596*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
597*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(convLayer, ignoreSlots),
598*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({convLayer}),
599*89c4ff92SAndroid Build Coastguard Worker {convLayer, weightsLayer, biasLayer});
600*89c4ff92SAndroid Build Coastguard Worker }
601*89c4ff92SAndroid Build Coastguard Worker
602*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with some unoptimizable layers ("unoptimizable" is added to the layer's name)
BuildPartiallyOptimizableSubgraph1(Graph & graph,LayerNameToLayerMap & layersInGraph)603*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallyOptimizableSubgraph1(Graph& graph, LayerNameToLayerMap& layersInGraph)
604*89c4ff92SAndroid Build Coastguard Worker {
605*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
606*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
607*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
608*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo ({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
609*89c4ff92SAndroid Build Coastguard Worker
610*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
611*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
612*89c4ff92SAndroid Build Coastguard Worker
613*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
614*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
615*89c4ff92SAndroid Build Coastguard Worker
616*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
617*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
618*89c4ff92SAndroid Build Coastguard Worker
619*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
620*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
621*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
622*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
623*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
624*89c4ff92SAndroid Build Coastguard Worker
625*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
626*89c4ff92SAndroid Build Coastguard Worker Layer* const inputLayer = AddInputLayer(graph, "input layer", inputInfo);
627*89c4ff92SAndroid Build Coastguard Worker
628*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer1 =
629*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
630*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer1 =
631*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
632*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer2 =
633*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 2 unoptimizable", constWeightsTensor, weightInfo);
634*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer2 =
635*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 2 unoptimizable", constBiasTensor, biasInfo);
636*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer3 =
637*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
638*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer3 =
639*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
640*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer4 =
641*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 4 unoptimizable", constWeightsTensor, weightInfo);
642*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer4 =
643*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 4 unoptimizable", constBiasTensor, biasInfo);
644*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer5 =
645*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 5", constWeightsTensor, weightInfo);
646*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer5 =
647*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 5", constBiasTensor, biasInfo);
648*89c4ff92SAndroid Build Coastguard Worker
649*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
650*89c4ff92SAndroid Build Coastguard Worker "conv1 layer", outputInfo);
651*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
652*89c4ff92SAndroid Build Coastguard Worker "conv2 layer unoptimizable", outputInfo);
653*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
654*89c4ff92SAndroid Build Coastguard Worker "conv3 layer", outputInfo);
655*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv4Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
656*89c4ff92SAndroid Build Coastguard Worker "conv4 layer unoptimizable", outputInfo);
657*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv5Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
658*89c4ff92SAndroid Build Coastguard Worker "conv5 layer", outputInfo);
659*89c4ff92SAndroid Build Coastguard Worker
660*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
661*89c4ff92SAndroid Build Coastguard Worker
662*89c4ff92SAndroid Build Coastguard Worker // Connect the network
663*89c4ff92SAndroid Build Coastguard Worker inputLayer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
664*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
665*89c4ff92SAndroid Build Coastguard Worker biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
666*89c4ff92SAndroid Build Coastguard Worker
667*89c4ff92SAndroid Build Coastguard Worker conv1Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
668*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
669*89c4ff92SAndroid Build Coastguard Worker biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
670*89c4ff92SAndroid Build Coastguard Worker
671*89c4ff92SAndroid Build Coastguard Worker conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
672*89c4ff92SAndroid Build Coastguard Worker weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
673*89c4ff92SAndroid Build Coastguard Worker biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
674*89c4ff92SAndroid Build Coastguard Worker
675*89c4ff92SAndroid Build Coastguard Worker conv3Layer->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(0));
676*89c4ff92SAndroid Build Coastguard Worker weightsLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(1));
677*89c4ff92SAndroid Build Coastguard Worker biasLayer4->GetOutputSlot(0).Connect(conv4Layer->GetInputSlot(2));
678*89c4ff92SAndroid Build Coastguard Worker
679*89c4ff92SAndroid Build Coastguard Worker conv4Layer->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(0));
680*89c4ff92SAndroid Build Coastguard Worker weightsLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(1));
681*89c4ff92SAndroid Build Coastguard Worker biasLayer5->GetOutputSlot(0).Connect(conv5Layer->GetInputSlot(2));
682*89c4ff92SAndroid Build Coastguard Worker
683*89c4ff92SAndroid Build Coastguard Worker conv5Layer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
684*89c4ff92SAndroid Build Coastguard Worker
685*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
686*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
687*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom(conv1Layer, ignoreSlots),
688*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({conv5Layer}),
689*89c4ff92SAndroid Build Coastguard Worker {weightsLayer1,
690*89c4ff92SAndroid Build Coastguard Worker biasLayer1,
691*89c4ff92SAndroid Build Coastguard Worker conv1Layer,
692*89c4ff92SAndroid Build Coastguard Worker weightsLayer2,
693*89c4ff92SAndroid Build Coastguard Worker biasLayer2,
694*89c4ff92SAndroid Build Coastguard Worker conv2Layer,
695*89c4ff92SAndroid Build Coastguard Worker weightsLayer3,
696*89c4ff92SAndroid Build Coastguard Worker biasLayer3,
697*89c4ff92SAndroid Build Coastguard Worker conv3Layer,
698*89c4ff92SAndroid Build Coastguard Worker weightsLayer4,
699*89c4ff92SAndroid Build Coastguard Worker biasLayer4,
700*89c4ff92SAndroid Build Coastguard Worker conv4Layer,
701*89c4ff92SAndroid Build Coastguard Worker weightsLayer5,
702*89c4ff92SAndroid Build Coastguard Worker biasLayer5,
703*89c4ff92SAndroid Build Coastguard Worker conv5Layer});
704*89c4ff92SAndroid Build Coastguard Worker }
705*89c4ff92SAndroid Build Coastguard Worker
706*89c4ff92SAndroid Build Coastguard Worker // Creates a subgraph with some input unoptimizable layers ("unoptimizable" is added to the layer's name),
707*89c4ff92SAndroid Build Coastguard Worker // this is meant to test input slots coming from different layers
BuildPartiallyOptimizableSubgraph2(Graph & graph,LayerNameToLayerMap & layersInGraph)708*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr BuildPartiallyOptimizableSubgraph2(Graph& graph, LayerNameToLayerMap& layersInGraph)
709*89c4ff92SAndroid Build Coastguard Worker {
710*89c4ff92SAndroid Build Coastguard Worker const TensorInfo inputInfo ({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
711*89c4ff92SAndroid Build Coastguard Worker const TensorInfo outputInfo({ 1, 16, 16, 16 }, DataType::QAsymmU8, 1.0f, 0);
712*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo({ 16, 1, 1, 16 }, DataType::QAsymmU8, 0.9f, 0);
713*89c4ff92SAndroid Build Coastguard Worker TensorInfo biasInfo ({ 1, 1, 1, 16 }, DataType::Signed32, 0.9f, 0);
714*89c4ff92SAndroid Build Coastguard Worker
715*89c4ff92SAndroid Build Coastguard Worker weightInfo.SetConstant(true);
716*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetConstant(true);
717*89c4ff92SAndroid Build Coastguard Worker
718*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(64);
719*89c4ff92SAndroid Build Coastguard Worker ConstTensor constWeightsTensor(weightInfo, weightsVector);
720*89c4ff92SAndroid Build Coastguard Worker
721*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector(16);
722*89c4ff92SAndroid Build Coastguard Worker ConstTensor constBiasTensor(biasInfo, biasVector);
723*89c4ff92SAndroid Build Coastguard Worker
724*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolutionDescriptor;
725*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideX = 1;
726*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_StrideY = 1;
727*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_BiasEnabled = true;
728*89c4ff92SAndroid Build Coastguard Worker convolutionDescriptor.m_DataLayout = DataLayout::NHWC;
729*89c4ff92SAndroid Build Coastguard Worker
730*89c4ff92SAndroid Build Coastguard Worker // Construct the graph
731*89c4ff92SAndroid Build Coastguard Worker Layer* const input1Layer = AddInputLayer(graph, "input1 layer", inputInfo, 0);
732*89c4ff92SAndroid Build Coastguard Worker Layer* const input2Layer = AddInputLayer(graph, "input2 layer", inputInfo, 1);
733*89c4ff92SAndroid Build Coastguard Worker
734*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer1 =
735*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 1", constWeightsTensor, weightInfo);
736*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer1 =
737*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 1", constBiasTensor, biasInfo);
738*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer2 =
739*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 2 unoptimizable", constWeightsTensor, weightInfo);
740*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer2 =
741*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 2 unoptimizable", constBiasTensor, biasInfo);
742*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const weightsLayer3 =
743*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Weights Layer 3", constWeightsTensor, weightInfo);
744*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* const biasLayer3 =
745*89c4ff92SAndroid Build Coastguard Worker AddConstantLayer(graph, layersInGraph, "Bias Layer 3", constBiasTensor, biasInfo);
746*89c4ff92SAndroid Build Coastguard Worker
747*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv1Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
748*89c4ff92SAndroid Build Coastguard Worker "conv1 layer", outputInfo);
749*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv2Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
750*89c4ff92SAndroid Build Coastguard Worker "conv2 layer unoptimizable", outputInfo);
751*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* const conv3Layer = AddConvolutionLayer(graph, layersInGraph, convolutionDescriptor,
752*89c4ff92SAndroid Build Coastguard Worker "conv3 layer", outputInfo);
753*89c4ff92SAndroid Build Coastguard Worker AdditionLayer* const addLayer = AddAdditionaLayer(graph, layersInGraph, "add layer", outputInfo);
754*89c4ff92SAndroid Build Coastguard Worker Layer* const outputLayer = AddOutputLayer(graph, "output layer");
755*89c4ff92SAndroid Build Coastguard Worker
756*89c4ff92SAndroid Build Coastguard Worker // Connect the network
757*89c4ff92SAndroid Build Coastguard Worker input1Layer->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(0));
758*89c4ff92SAndroid Build Coastguard Worker weightsLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(1));
759*89c4ff92SAndroid Build Coastguard Worker biasLayer1->GetOutputSlot(0).Connect(conv1Layer->GetInputSlot(2));
760*89c4ff92SAndroid Build Coastguard Worker conv1Layer->GetOutputSlot(0).Connect(addLayer->GetInputSlot(0));
761*89c4ff92SAndroid Build Coastguard Worker
762*89c4ff92SAndroid Build Coastguard Worker input2Layer->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(0));
763*89c4ff92SAndroid Build Coastguard Worker weightsLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(1));
764*89c4ff92SAndroid Build Coastguard Worker biasLayer2->GetOutputSlot(0).Connect(conv2Layer->GetInputSlot(2));
765*89c4ff92SAndroid Build Coastguard Worker conv2Layer->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(0));
766*89c4ff92SAndroid Build Coastguard Worker weightsLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(1));
767*89c4ff92SAndroid Build Coastguard Worker biasLayer3->GetOutputSlot(0).Connect(conv3Layer->GetInputSlot(2));
768*89c4ff92SAndroid Build Coastguard Worker conv3Layer->GetOutputSlot(0).Connect(addLayer->GetInputSlot(1));
769*89c4ff92SAndroid Build Coastguard Worker
770*89c4ff92SAndroid Build Coastguard Worker addLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
771*89c4ff92SAndroid Build Coastguard Worker
772*89c4ff92SAndroid Build Coastguard Worker // Create the subgraph view for the whole network
773*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> ignoreSlots = {1, 2};
774*89c4ff92SAndroid Build Coastguard Worker return CreateSubgraphViewFrom(CreateInputsFrom({conv1Layer,
775*89c4ff92SAndroid Build Coastguard Worker conv2Layer}, ignoreSlots),
776*89c4ff92SAndroid Build Coastguard Worker CreateOutputsFrom({addLayer}),
777*89c4ff92SAndroid Build Coastguard Worker { weightsLayer1,
778*89c4ff92SAndroid Build Coastguard Worker biasLayer1,
779*89c4ff92SAndroid Build Coastguard Worker weightsLayer2,
780*89c4ff92SAndroid Build Coastguard Worker biasLayer2,
781*89c4ff92SAndroid Build Coastguard Worker weightsLayer3,
782*89c4ff92SAndroid Build Coastguard Worker biasLayer3,
783*89c4ff92SAndroid Build Coastguard Worker conv1Layer,
784*89c4ff92SAndroid Build Coastguard Worker conv2Layer,
785*89c4ff92SAndroid Build Coastguard Worker conv3Layer,
786*89c4ff92SAndroid Build Coastguard Worker addLayer });
787*89c4ff92SAndroid Build Coastguard Worker }
788*89c4ff92SAndroid Build Coastguard Worker
789*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only a single unsupported layer (only convolutions are unsupported by the mock backend)
FullyUnsupporteSubgraphTestImpl1()790*89c4ff92SAndroid Build Coastguard Worker void FullyUnsupporteSubgraphTestImpl1()
791*89c4ff92SAndroid Build Coastguard Worker {
792*89c4ff92SAndroid Build Coastguard Worker Graph graph;
793*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
794*89c4ff92SAndroid Build Coastguard Worker
795*89c4ff92SAndroid Build Coastguard Worker // Create an unsupported subgraph
796*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnsupportedSubgraph1(graph, layersInGraph);
797*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
798*89c4ff92SAndroid Build Coastguard Worker
799*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
800*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
801*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
802*89c4ff92SAndroid Build Coastguard Worker
803*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
804*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
805*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 1);
806*89c4ff92SAndroid Build Coastguard Worker
807*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling layer"));
808*89c4ff92SAndroid Build Coastguard Worker
809*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
810*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
811*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
812*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
815*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
816*89c4ff92SAndroid Build Coastguard Worker
817*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly, but no optimization is performed
818*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
819*89c4ff92SAndroid Build Coastguard Worker
820*89c4ff92SAndroid Build Coastguard Worker // =======================================================================
821*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
822*89c4ff92SAndroid Build Coastguard Worker // - No substitutions
823*89c4ff92SAndroid Build Coastguard Worker // - Exactly one failed subgraph, corresponding to the whole original one
824*89c4ff92SAndroid Build Coastguard Worker // - No untouched subgraphs
825*89c4ff92SAndroid Build Coastguard Worker // =======================================================================
826*89c4ff92SAndroid Build Coastguard Worker
827*89c4ff92SAndroid Build Coastguard Worker // -----------------------
828*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
829*89c4ff92SAndroid Build Coastguard Worker // -----------------------
830*89c4ff92SAndroid Build Coastguard Worker
831*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetSubstitutions().empty());
832*89c4ff92SAndroid Build Coastguard Worker
833*89c4ff92SAndroid Build Coastguard Worker // --------------------------
834*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
835*89c4ff92SAndroid Build Coastguard Worker // --------------------------
836*89c4ff92SAndroid Build Coastguard Worker
837*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs();
838*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphs.size() == 1);
839*89c4ff92SAndroid Build Coastguard Worker
840*89c4ff92SAndroid Build Coastguard Worker CheckFailedSubgraph(failedSubgraphs.at(0),
841*89c4ff92SAndroid Build Coastguard Worker { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
842*89c4ff92SAndroid Build Coastguard Worker subgraphInputSlots,
843*89c4ff92SAndroid Build Coastguard Worker subgraphOutputSlots,
844*89c4ff92SAndroid Build Coastguard Worker subgraphLayers);
845*89c4ff92SAndroid Build Coastguard Worker
846*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
847*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
848*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
849*89c4ff92SAndroid Build Coastguard Worker
850*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
851*89c4ff92SAndroid Build Coastguard Worker }
852*89c4ff92SAndroid Build Coastguard Worker
853*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only unsupported layers (only convolutions are unsupported by the mock backend)
FullyUnsupporteSubgraphTestImpl2()854*89c4ff92SAndroid Build Coastguard Worker void FullyUnsupporteSubgraphTestImpl2()
855*89c4ff92SAndroid Build Coastguard Worker {
856*89c4ff92SAndroid Build Coastguard Worker Graph graph;
857*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
858*89c4ff92SAndroid Build Coastguard Worker
859*89c4ff92SAndroid Build Coastguard Worker // Create an unsupported subgraph
860*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnsupportedSubgraph2(graph, layersInGraph);
861*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
862*89c4ff92SAndroid Build Coastguard Worker
863*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
864*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
865*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
866*89c4ff92SAndroid Build Coastguard Worker
867*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
868*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
869*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 3);
870*89c4ff92SAndroid Build Coastguard Worker
871*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling1 layer"));
872*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling2 layer"));
873*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling3 layer"));
874*89c4ff92SAndroid Build Coastguard Worker
875*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
876*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
877*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
878*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
879*89c4ff92SAndroid Build Coastguard Worker
880*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
881*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
882*89c4ff92SAndroid Build Coastguard Worker
883*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly, but no optimization is performed
884*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
885*89c4ff92SAndroid Build Coastguard Worker
886*89c4ff92SAndroid Build Coastguard Worker // =======================================================================
887*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
888*89c4ff92SAndroid Build Coastguard Worker // - No substitutions
889*89c4ff92SAndroid Build Coastguard Worker // - Exactly one failed subgraph, corresponding to the whole original one
890*89c4ff92SAndroid Build Coastguard Worker // - No untouched subgraphs
891*89c4ff92SAndroid Build Coastguard Worker // =======================================================================
892*89c4ff92SAndroid Build Coastguard Worker
893*89c4ff92SAndroid Build Coastguard Worker // -----------------------
894*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
895*89c4ff92SAndroid Build Coastguard Worker // -----------------------
896*89c4ff92SAndroid Build Coastguard Worker
897*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetSubstitutions().empty());
898*89c4ff92SAndroid Build Coastguard Worker
899*89c4ff92SAndroid Build Coastguard Worker // --------------------------
900*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
901*89c4ff92SAndroid Build Coastguard Worker // --------------------------
902*89c4ff92SAndroid Build Coastguard Worker
903*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Subgraphs& failedSubgraphs = optimizationViews.GetFailedSubgraphs();
904*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphs.size() == 1);
905*89c4ff92SAndroid Build Coastguard Worker
906*89c4ff92SAndroid Build Coastguard Worker std::list<IConnectableLayer*> expectedFailedLayers{ layersInGraph.at("pooling1 layer"),
907*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("pooling2 layer"),
908*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("pooling3 layer") };
909*89c4ff92SAndroid Build Coastguard Worker
910*89c4ff92SAndroid Build Coastguard Worker const SubgraphView& failedSubgraph = failedSubgraphs.at(0);
911*89c4ff92SAndroid Build Coastguard Worker
912*89c4ff92SAndroid Build Coastguard Worker CheckFailedSubgraph(failedSubgraph,
913*89c4ff92SAndroid Build Coastguard Worker { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
914*89c4ff92SAndroid Build Coastguard Worker subgraphInputSlots,
915*89c4ff92SAndroid Build Coastguard Worker subgraphOutputSlots,
916*89c4ff92SAndroid Build Coastguard Worker subgraphLayers);
917*89c4ff92SAndroid Build Coastguard Worker
918*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& failedSubgraphLayers = failedSubgraph.GetIConnectableLayers();
919*89c4ff92SAndroid Build Coastguard Worker
920*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(failedSubgraphLayers.front() + 0, expectedFailedLayers.front() + 0);
921*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(failedSubgraphLayers.front() + 1, expectedFailedLayers.front() + 1);
922*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(failedSubgraphLayers.front() + 2, expectedFailedLayers.front() + 2);
923*89c4ff92SAndroid Build Coastguard Worker
924*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
925*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
926*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
927*89c4ff92SAndroid Build Coastguard Worker
928*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
929*89c4ff92SAndroid Build Coastguard Worker }
930*89c4ff92SAndroid Build Coastguard Worker
931*89c4ff92SAndroid Build Coastguard Worker // A simple case with only one layer (convolution) to optimize, supported by the mock backend
FullyOptimizableSubgraphTestImpl1()932*89c4ff92SAndroid Build Coastguard Worker void FullyOptimizableSubgraphTestImpl1()
933*89c4ff92SAndroid Build Coastguard Worker {
934*89c4ff92SAndroid Build Coastguard Worker Graph graph;
935*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
936*89c4ff92SAndroid Build Coastguard Worker
937*89c4ff92SAndroid Build Coastguard Worker // Create a fully optimizable subgraph
938*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyOptimizableSubgraph1(graph, layersInGraph);
939*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
940*89c4ff92SAndroid Build Coastguard Worker
941*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
942*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
943*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
944*89c4ff92SAndroid Build Coastguard Worker
945*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
946*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
947*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 3);
948*89c4ff92SAndroid Build Coastguard Worker
949*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv layer"));
950*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer"));
951*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer"));
952*89c4ff92SAndroid Build Coastguard Worker
953*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
954*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
955*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
956*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
957*89c4ff92SAndroid Build Coastguard Worker
958*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
959*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
960*89c4ff92SAndroid Build Coastguard Worker
961*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
962*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
963*89c4ff92SAndroid Build Coastguard Worker
964*89c4ff92SAndroid Build Coastguard Worker // ===========================================================================================
965*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
966*89c4ff92SAndroid Build Coastguard Worker // - Exactly one substitution, mapping the whole input subgraph to a new replacement subgraph
967*89c4ff92SAndroid Build Coastguard Worker // - No failed subgraphs
968*89c4ff92SAndroid Build Coastguard Worker // - No untouched subgraphs
969*89c4ff92SAndroid Build Coastguard Worker // ===========================================================================================
970*89c4ff92SAndroid Build Coastguard Worker
971*89c4ff92SAndroid Build Coastguard Worker // -----------------------
972*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
973*89c4ff92SAndroid Build Coastguard Worker // -----------------------
974*89c4ff92SAndroid Build Coastguard Worker
975*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
976*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutions.size() == 1);
977*89c4ff92SAndroid Build Coastguard Worker
978*89c4ff92SAndroid Build Coastguard Worker CheckSubstitution(substitutions.at(0),
979*89c4ff92SAndroid Build Coastguard Worker { subgraphInputSlots.size(), subgraphOutputSlots.size(), subgraphLayers.size() },
980*89c4ff92SAndroid Build Coastguard Worker { subgraphInputSlots.size(), subgraphOutputSlots.size(), 1 },
981*89c4ff92SAndroid Build Coastguard Worker subgraphInputSlots,
982*89c4ff92SAndroid Build Coastguard Worker subgraphOutputSlots,
983*89c4ff92SAndroid Build Coastguard Worker subgraphLayers);
984*89c4ff92SAndroid Build Coastguard Worker
985*89c4ff92SAndroid Build Coastguard Worker // --------------------------
986*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
987*89c4ff92SAndroid Build Coastguard Worker // --------------------------
988*89c4ff92SAndroid Build Coastguard Worker
989*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetFailedSubgraphs().empty());
990*89c4ff92SAndroid Build Coastguard Worker
991*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
992*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
993*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
994*89c4ff92SAndroid Build Coastguard Worker
995*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
996*89c4ff92SAndroid Build Coastguard Worker }
997*89c4ff92SAndroid Build Coastguard Worker
998*89c4ff92SAndroid Build Coastguard Worker // A case with five layers (all convolutions) to optimize, all supported by the mock backend
FullyOptimizableSubgraphTestImpl2()999*89c4ff92SAndroid Build Coastguard Worker void FullyOptimizableSubgraphTestImpl2()
1000*89c4ff92SAndroid Build Coastguard Worker {
1001*89c4ff92SAndroid Build Coastguard Worker Graph graph;
1002*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
1003*89c4ff92SAndroid Build Coastguard Worker
1004*89c4ff92SAndroid Build Coastguard Worker // Create a fully optimizable subgraph
1005*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyOptimizableSubgraph2(graph, layersInGraph);
1006*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
1007*89c4ff92SAndroid Build Coastguard Worker
1008*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1009*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1010*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1011*89c4ff92SAndroid Build Coastguard Worker
1012*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
1013*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
1014*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphPtr->GetIConnectableLayers().size() == 15);
1015*89c4ff92SAndroid Build Coastguard Worker
1016*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv1 layer"));
1017*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv2 layer"));
1018*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv3 layer"));
1019*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv4 layer"));
1020*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv5 layer"));
1021*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 1"));
1022*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 2"));
1023*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 3"));
1024*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 4"));
1025*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 5"));
1026*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 1"));
1027*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 2"));
1028*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 3"));
1029*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 4"));
1030*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 5"));
1031*89c4ff92SAndroid Build Coastguard Worker
1032*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
1033*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
1034*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
1035*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
1036*89c4ff92SAndroid Build Coastguard Worker
1037*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
1038*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
1039*89c4ff92SAndroid Build Coastguard Worker
1040*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
1041*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1042*89c4ff92SAndroid Build Coastguard Worker
1043*89c4ff92SAndroid Build Coastguard Worker // ===========================================================================================
1044*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
1045*89c4ff92SAndroid Build Coastguard Worker // - Exactly one substitution, mapping the whole input subgraph to a new replacement subgraph
1046*89c4ff92SAndroid Build Coastguard Worker // - No failed subgraphs
1047*89c4ff92SAndroid Build Coastguard Worker // - No untouched subgraphs
1048*89c4ff92SAndroid Build Coastguard Worker // ===========================================================================================
1049*89c4ff92SAndroid Build Coastguard Worker
1050*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1051*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
1052*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1053*89c4ff92SAndroid Build Coastguard Worker
1054*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
1055*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutions.size() == 1);
1056*89c4ff92SAndroid Build Coastguard Worker
1057*89c4ff92SAndroid Build Coastguard Worker std::list<IConnectableLayer*> expectedSubstitutableLayers{
1058*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 1"),
1059*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 2"),
1060*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 3"),
1061*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 4"),
1062*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 5"),
1063*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 1"),
1064*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 2"),
1065*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 3"),
1066*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 4"),
1067*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 5"),
1068*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv1 layer"),
1069*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv2 layer"),
1070*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv3 layer"),
1071*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv4 layer"),
1072*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv5 layer")};
1073*89c4ff92SAndroid Build Coastguard Worker
1074*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::SubstitutionPair& substitution = substitutions.at(0);
1075*89c4ff92SAndroid Build Coastguard Worker
1076*89c4ff92SAndroid Build Coastguard Worker CheckSubstitution(
1077*89c4ff92SAndroid Build Coastguard Worker substitution,
1078*89c4ff92SAndroid Build Coastguard Worker {subgraphInputSlots.size(), subgraphOutputSlots.size(),
1079*89c4ff92SAndroid Build Coastguard Worker subgraphLayers.size()},
1080*89c4ff92SAndroid Build Coastguard Worker {subgraphInputSlots.size(), subgraphOutputSlots.size(), 1},
1081*89c4ff92SAndroid Build Coastguard Worker subgraphInputSlots, subgraphOutputSlots, expectedSubstitutableLayers);
1082*89c4ff92SAndroid Build Coastguard Worker
1083*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& substitutableSubgraphLayers =
1084*89c4ff92SAndroid Build Coastguard Worker substitution.m_SubstitutableSubgraph.GetIConnectableLayers();
1085*89c4ff92SAndroid Build Coastguard Worker
1086*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(substitutableSubgraphLayers.front() + 0, expectedSubstitutableLayers.front() + 0);
1087*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(substitutableSubgraphLayers.front() + 1, expectedSubstitutableLayers.front() + 1);
1088*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(substitutableSubgraphLayers.front() + 2, expectedSubstitutableLayers.front() + 2);
1089*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(substitutableSubgraphLayers.front() + 3, expectedSubstitutableLayers.front() + 3);
1090*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(substitutableSubgraphLayers.front() + 4, expectedSubstitutableLayers.front() + 4);
1091*89c4ff92SAndroid Build Coastguard Worker
1092*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1093*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
1094*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1095*89c4ff92SAndroid Build Coastguard Worker
1096*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetFailedSubgraphs().empty());
1097*89c4ff92SAndroid Build Coastguard Worker
1098*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1099*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
1100*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1101*89c4ff92SAndroid Build Coastguard Worker
1102*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
1103*89c4ff92SAndroid Build Coastguard Worker }
1104*89c4ff92SAndroid Build Coastguard Worker
1105*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contaions both supported and unsupported layers
1106*89c4ff92SAndroid Build Coastguard Worker // (but only convolutions are unsupported by the mock backend)
PartiallySupportedSubgraphTestImpl()1107*89c4ff92SAndroid Build Coastguard Worker void PartiallySupportedSubgraphTestImpl()
1108*89c4ff92SAndroid Build Coastguard Worker {
1109*89c4ff92SAndroid Build Coastguard Worker Graph graph;
1110*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
1111*89c4ff92SAndroid Build Coastguard Worker
1112*89c4ff92SAndroid Build Coastguard Worker // Create a fully optimizable subgraph
1113*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallySupportedSubgraph(graph, layersInGraph);
1114*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
1115*89c4ff92SAndroid Build Coastguard Worker
1116*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1117*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1118*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1119*89c4ff92SAndroid Build Coastguard Worker
1120*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
1121*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
1122*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 9);
1123*89c4ff92SAndroid Build Coastguard Worker
1124*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 1"));
1125*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 1"));
1126*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv1 layer"));
1127*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling1 layer"));
1128*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling2 layer"));
1129*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Weights Layer 2"));
1130*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "Bias Layer 2"));
1131*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv2 layer"));
1132*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "pooling3 layer"));
1133*89c4ff92SAndroid Build Coastguard Worker
1134*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
1135*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
1136*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
1137*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
1138*89c4ff92SAndroid Build Coastguard Worker
1139*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
1140*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
1141*89c4ff92SAndroid Build Coastguard Worker
1142*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
1143*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1144*89c4ff92SAndroid Build Coastguard Worker
1145*89c4ff92SAndroid Build Coastguard Worker // ========================================================================
1146*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
1147*89c4ff92SAndroid Build Coastguard Worker // - Exactly two substitution, corresponding to the supported layers
1148*89c4ff92SAndroid Build Coastguard Worker // - Exactly two failed subgraphs, corresponding to the unsupported layers
1149*89c4ff92SAndroid Build Coastguard Worker // - No untouched subgraphs
1150*89c4ff92SAndroid Build Coastguard Worker // ========================================================================
1151*89c4ff92SAndroid Build Coastguard Worker
1152*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1153*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
1154*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1155*89c4ff92SAndroid Build Coastguard Worker
1156*89c4ff92SAndroid Build Coastguard Worker OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
1157*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutions.size() == 2);
1158*89c4ff92SAndroid Build Coastguard Worker // Sort into a consistent order
1159*89c4ff92SAndroid Build Coastguard Worker std::sort(substitutions.begin(), substitutions.end(), [](auto s1, auto s2) {
1160*89c4ff92SAndroid Build Coastguard Worker return strcmp(s1.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName(),
1161*89c4ff92SAndroid Build Coastguard Worker s2.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName()) < 0;
1162*89c4ff92SAndroid Build Coastguard Worker });
1163*89c4ff92SAndroid Build Coastguard Worker
1164*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 3 },
1165*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 3 } };
1166*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 },
1167*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 1 } };
1168*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IInputSlots> expectedSubstitutableInputSlots
1169*89c4ff92SAndroid Build Coastguard Worker {
1170*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1171*89c4ff92SAndroid Build Coastguard Worker {ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlot(0))}),
1172*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1173*89c4ff92SAndroid Build Coastguard Worker {ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer")->GetInputSlot(0))})
1174*89c4ff92SAndroid Build Coastguard Worker };
1175*89c4ff92SAndroid Build Coastguard Worker
1176*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IOutputSlots> expectedSubstitutableOutputSlots
1177*89c4ff92SAndroid Build Coastguard Worker {
1178*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1179*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots())),
1180*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1181*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer")->GetOutputSlots()))
1182*89c4ff92SAndroid Build Coastguard Worker };
1183*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IConnectableLayers> expectedSubstitutableLayers
1184*89c4ff92SAndroid Build Coastguard Worker {
1185*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 1"), layersInGraph.at("Bias Layer 1"), layersInGraph.at("conv1 layer") },
1186*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 2"), layersInGraph.at("Bias Layer 2"), layersInGraph.at("conv2 layer") }
1187*89c4ff92SAndroid Build Coastguard Worker };
1188*89c4ff92SAndroid Build Coastguard Worker
1189*89c4ff92SAndroid Build Coastguard Worker for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++)
1190*89c4ff92SAndroid Build Coastguard Worker {
1191*89c4ff92SAndroid Build Coastguard Worker CheckSubstitution(substitutions.at(substitutionIndex),
1192*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableSubgraphSizes.at(substitutionIndex),
1193*89c4ff92SAndroid Build Coastguard Worker expectedReplacementSubgraphSizes.at(substitutionIndex),
1194*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableInputSlots.at(substitutionIndex),
1195*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableOutputSlots.at(substitutionIndex),
1196*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableLayers.at(substitutionIndex));
1197*89c4ff92SAndroid Build Coastguard Worker }
1198*89c4ff92SAndroid Build Coastguard Worker
1199*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1200*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
1201*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1202*89c4ff92SAndroid Build Coastguard Worker
1203*89c4ff92SAndroid Build Coastguard Worker OptimizationViews::Subgraphs failedSubgraphs = optimizationViews.GetFailedSubgraphs();
1204*89c4ff92SAndroid Build Coastguard Worker CHECK(failedSubgraphs.size() == 2);
1205*89c4ff92SAndroid Build Coastguard Worker // Sort into a consistent order
1206*89c4ff92SAndroid Build Coastguard Worker std::sort(failedSubgraphs.begin(), failedSubgraphs.end(), [](auto s1, auto s2) {
1207*89c4ff92SAndroid Build Coastguard Worker return strcmp(s1.GetIConnectableLayers().front()->GetName(),
1208*89c4ff92SAndroid Build Coastguard Worker s2.GetIConnectableLayers().front()->GetName()) < 0;
1209*89c4ff92SAndroid Build Coastguard Worker });
1210*89c4ff92SAndroid Build Coastguard Worker
1211*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedFailedSubgraphSizes{ { 1, 1, 2 },
1212*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 1 } };
1213*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IInputSlots> expectedFailedInputSlots
1214*89c4ff92SAndroid Build Coastguard Worker {
1215*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1216*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("pooling1 layer")->GetInputSlots())),
1217*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1218*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("pooling3 layer")->GetInputSlots()))
1219*89c4ff92SAndroid Build Coastguard Worker };
1220*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IOutputSlots> expectedFailedOutputSlots
1221*89c4ff92SAndroid Build Coastguard Worker {
1222*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1223*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("pooling2 layer")->GetOutputSlots())),
1224*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1225*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("pooling3 layer")->GetOutputSlots()))
1226*89c4ff92SAndroid Build Coastguard Worker };
1227*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IConnectableLayers> expectedFailedLayers
1228*89c4ff92SAndroid Build Coastguard Worker {
1229*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("pooling1 layer"),
1230*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("pooling2 layer") },
1231*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("pooling3 layer") }
1232*89c4ff92SAndroid Build Coastguard Worker };
1233*89c4ff92SAndroid Build Coastguard Worker
1234*89c4ff92SAndroid Build Coastguard Worker for (size_t failedIndex = 0; failedIndex < failedSubgraphs.size(); failedIndex++)
1235*89c4ff92SAndroid Build Coastguard Worker {
1236*89c4ff92SAndroid Build Coastguard Worker CheckFailedSubgraph(failedSubgraphs.at(failedIndex),
1237*89c4ff92SAndroid Build Coastguard Worker expectedFailedSubgraphSizes.at(failedIndex),
1238*89c4ff92SAndroid Build Coastguard Worker expectedFailedInputSlots.at(failedIndex),
1239*89c4ff92SAndroid Build Coastguard Worker expectedFailedOutputSlots.at(failedIndex),
1240*89c4ff92SAndroid Build Coastguard Worker expectedFailedLayers.at(failedIndex));
1241*89c4ff92SAndroid Build Coastguard Worker }
1242*89c4ff92SAndroid Build Coastguard Worker
1243*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1244*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
1245*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1246*89c4ff92SAndroid Build Coastguard Worker
1247*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetUntouchedSubgraphs().empty());
1248*89c4ff92SAndroid Build Coastguard Worker }
1249*89c4ff92SAndroid Build Coastguard Worker
1250*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains only unoptimizable layers ("unoptimizable" is added to the layer's name)
FullyUnoptimizableSubgraphTestImpl1()1251*89c4ff92SAndroid Build Coastguard Worker void FullyUnoptimizableSubgraphTestImpl1()
1252*89c4ff92SAndroid Build Coastguard Worker {
1253*89c4ff92SAndroid Build Coastguard Worker Graph graph;
1254*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
1255*89c4ff92SAndroid Build Coastguard Worker
1256*89c4ff92SAndroid Build Coastguard Worker // Create a fully optimizable subgraph
1257*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildFullyUnoptimizableSubgraph1(graph, layersInGraph);
1258*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
1259*89c4ff92SAndroid Build Coastguard Worker
1260*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1261*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1262*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1263*89c4ff92SAndroid Build Coastguard Worker
1264*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
1265*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
1266*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 3);
1267*89c4ff92SAndroid Build Coastguard Worker
1268*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv layer unoptimizable"));
1269*89c4ff92SAndroid Build Coastguard Worker
1270*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
1271*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
1272*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
1273*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
1274*89c4ff92SAndroid Build Coastguard Worker
1275*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
1276*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
1277*89c4ff92SAndroid Build Coastguard Worker
1278*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
1279*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1280*89c4ff92SAndroid Build Coastguard Worker
1281*89c4ff92SAndroid Build Coastguard Worker // ============================================================================
1282*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
1283*89c4ff92SAndroid Build Coastguard Worker // - No substitutions
1284*89c4ff92SAndroid Build Coastguard Worker // - No failed subgraphs
1285*89c4ff92SAndroid Build Coastguard Worker // - Exactly one untouched subgraph, corresponding to the whole input subgraph
1286*89c4ff92SAndroid Build Coastguard Worker // ============================================================================
1287*89c4ff92SAndroid Build Coastguard Worker
1288*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1289*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
1290*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1291*89c4ff92SAndroid Build Coastguard Worker
1292*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetSubstitutions().empty());
1293*89c4ff92SAndroid Build Coastguard Worker
1294*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1295*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
1296*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1297*89c4ff92SAndroid Build Coastguard Worker
1298*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetFailedSubgraphs().empty());
1299*89c4ff92SAndroid Build Coastguard Worker
1300*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1301*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
1302*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1303*89c4ff92SAndroid Build Coastguard Worker
1304*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1305*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphs.size() == 1);
1306*89c4ff92SAndroid Build Coastguard Worker
1307*89c4ff92SAndroid Build Coastguard Worker CheckUntouchedSubgraph(untouchedSubgraphs.at(0),
1308*89c4ff92SAndroid Build Coastguard Worker {subgraphInputSlots.size(),
1309*89c4ff92SAndroid Build Coastguard Worker subgraphOutputSlots.size(), subgraphLayers.size()},
1310*89c4ff92SAndroid Build Coastguard Worker subgraphInputSlots, subgraphOutputSlots,
1311*89c4ff92SAndroid Build Coastguard Worker subgraphLayers);
1312*89c4ff92SAndroid Build Coastguard Worker }
1313*89c4ff92SAndroid Build Coastguard Worker
1314*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains some unoptimizable layers ("unoptimizable" is added to the layer's name)
PartiallyOptimizableSubgraphTestImpl1()1315*89c4ff92SAndroid Build Coastguard Worker void PartiallyOptimizableSubgraphTestImpl1()
1316*89c4ff92SAndroid Build Coastguard Worker {
1317*89c4ff92SAndroid Build Coastguard Worker Graph graph;
1318*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
1319*89c4ff92SAndroid Build Coastguard Worker
1320*89c4ff92SAndroid Build Coastguard Worker // Create a fully optimizable subgraph
1321*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph1(graph, layersInGraph);
1322*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
1323*89c4ff92SAndroid Build Coastguard Worker
1324*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1325*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1326*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1327*89c4ff92SAndroid Build Coastguard Worker
1328*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 1);
1329*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
1330*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 15);
1331*89c4ff92SAndroid Build Coastguard Worker
1332*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv1 layer"));
1333*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv2 layer unoptimizable"));
1334*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv3 layer"));
1335*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv4 layer unoptimizable"));
1336*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv5 layer"));
1337*89c4ff92SAndroid Build Coastguard Worker
1338*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
1339*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
1340*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
1341*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
1342*89c4ff92SAndroid Build Coastguard Worker
1343*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
1344*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
1345*89c4ff92SAndroid Build Coastguard Worker
1346*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
1347*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1348*89c4ff92SAndroid Build Coastguard Worker
1349*89c4ff92SAndroid Build Coastguard Worker // ===============================================================================
1350*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
1351*89c4ff92SAndroid Build Coastguard Worker // - Exactly three substitutions, corresponding to the optimizable layers
1352*89c4ff92SAndroid Build Coastguard Worker // - No failed subgraphs
1353*89c4ff92SAndroid Build Coastguard Worker // - Exactly two untouched subgraphs, corresponding to the non-optimizable layers
1354*89c4ff92SAndroid Build Coastguard Worker // ===============================================================================
1355*89c4ff92SAndroid Build Coastguard Worker
1356*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1357*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
1358*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1359*89c4ff92SAndroid Build Coastguard Worker
1360*89c4ff92SAndroid Build Coastguard Worker OptimizationViews::Substitutions substitutions = optimizationViews.GetSubstitutions();
1361*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutions.size() == 3);
1362*89c4ff92SAndroid Build Coastguard Worker // Sort into a consistent order
1363*89c4ff92SAndroid Build Coastguard Worker std::sort(substitutions.begin(), substitutions.end(),
1364*89c4ff92SAndroid Build Coastguard Worker [](auto s1, auto s2)
1365*89c4ff92SAndroid Build Coastguard Worker { return strcmp(s1.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName(),
1366*89c4ff92SAndroid Build Coastguard Worker s2.m_SubstitutableSubgraph.GetIConnectableLayers().front()->GetName()) < 0; });
1367*89c4ff92SAndroid Build Coastguard Worker
1368*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedSubstitutableSubgraphSizes{ { 1, 1, 3 },
1369*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 3 },
1370*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 3 } };
1371*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedReplacementSubgraphSizes{ { 1, 1, 1 },
1372*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 1 },
1373*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 1 } };
1374*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IInputSlots> expectedSubstitutableInputSlots
1375*89c4ff92SAndroid Build Coastguard Worker {
1376*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1377*89c4ff92SAndroid Build Coastguard Worker {ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlot(0))}),
1378*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1379*89c4ff92SAndroid Build Coastguard Worker {ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlot(0))}),
1380*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>(
1381*89c4ff92SAndroid Build Coastguard Worker {ConvertReferenceTypeToPointerType(layersInGraph.at("conv5 layer")->GetInputSlot(0))})
1382*89c4ff92SAndroid Build Coastguard Worker };
1383*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IOutputSlots> expectedSubstitutableOutputSlots
1384*89c4ff92SAndroid Build Coastguard Worker {
1385*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1386*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetOutputSlots())),
1387*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1388*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetOutputSlots())),
1389*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1390*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv5 layer")->GetOutputSlots()))
1391*89c4ff92SAndroid Build Coastguard Worker };
1392*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IConnectableLayers> expectedSubstitutableLayers
1393*89c4ff92SAndroid Build Coastguard Worker {
1394*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 1"), layersInGraph.at("Bias Layer 1"), layersInGraph.at("conv1 layer") },
1395*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 3"), layersInGraph.at("Bias Layer 3"), layersInGraph.at("conv3 layer") },
1396*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 5"), layersInGraph.at("Bias Layer 5"), layersInGraph.at("conv5 layer") }
1397*89c4ff92SAndroid Build Coastguard Worker };
1398*89c4ff92SAndroid Build Coastguard Worker
1399*89c4ff92SAndroid Build Coastguard Worker for (size_t substitutionIndex = 0; substitutionIndex < substitutions.size(); substitutionIndex++)
1400*89c4ff92SAndroid Build Coastguard Worker {
1401*89c4ff92SAndroid Build Coastguard Worker CheckSubstitution(substitutions.at(substitutionIndex),
1402*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableSubgraphSizes.at(substitutionIndex),
1403*89c4ff92SAndroid Build Coastguard Worker expectedReplacementSubgraphSizes.at(substitutionIndex),
1404*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableInputSlots.at(substitutionIndex),
1405*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableOutputSlots.at(substitutionIndex),
1406*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableLayers.at(substitutionIndex));
1407*89c4ff92SAndroid Build Coastguard Worker }
1408*89c4ff92SAndroid Build Coastguard Worker
1409*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1410*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
1411*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1412*89c4ff92SAndroid Build Coastguard Worker
1413*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetFailedSubgraphs().empty());
1414*89c4ff92SAndroid Build Coastguard Worker
1415*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1416*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
1417*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1418*89c4ff92SAndroid Build Coastguard Worker
1419*89c4ff92SAndroid Build Coastguard Worker OptimizationViews::Subgraphs untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1420*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphs.size() == 2);
1421*89c4ff92SAndroid Build Coastguard Worker // Sort into a consistent order
1422*89c4ff92SAndroid Build Coastguard Worker std::sort(untouchedSubgraphs.begin(), untouchedSubgraphs.end(), [](auto s1, auto s2) {
1423*89c4ff92SAndroid Build Coastguard Worker return strcmp(s1.GetIConnectableLayers().front()->GetName(),
1424*89c4ff92SAndroid Build Coastguard Worker s2.GetIConnectableLayers().front()->GetName()) < 0;
1425*89c4ff92SAndroid Build Coastguard Worker });
1426*89c4ff92SAndroid Build Coastguard Worker
1427*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 3 },
1428*89c4ff92SAndroid Build Coastguard Worker { 1, 1, 3 } };
1429*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IInputSlots> expectedUntouchedInputSlots{
1430*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot,
1431*89c4ff92SAndroid Build Coastguard Worker IInputSlot>({ConvertReferenceTypeToPointerType(
1432*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv2 layer unoptimizable")->GetInputSlot(0))}),
1433*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot,
1434*89c4ff92SAndroid Build Coastguard Worker IInputSlot>({ConvertReferenceTypeToPointerType(
1435*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv4 layer unoptimizable")->GetInputSlot(0))})};
1436*89c4ff92SAndroid Build Coastguard Worker
1437*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IOutputSlots> expectedUntouchedOutputSlots
1438*89c4ff92SAndroid Build Coastguard Worker {
1439*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1440*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer unoptimizable")->GetOutputSlots())),
1441*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1442*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv4 layer unoptimizable")->GetOutputSlots()))
1443*89c4ff92SAndroid Build Coastguard Worker };
1444*89c4ff92SAndroid Build Coastguard Worker
1445*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IConnectableLayers> expectedUntouchedLayers
1446*89c4ff92SAndroid Build Coastguard Worker {
1447*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 2 unoptimizable"),
1448*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 2 unoptimizable"),
1449*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv2 layer unoptimizable") },
1450*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("Weights Layer 4 unoptimizable"),
1451*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 4 unoptimizable"),
1452*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv4 layer unoptimizable") }
1453*89c4ff92SAndroid Build Coastguard Worker };
1454*89c4ff92SAndroid Build Coastguard Worker
1455*89c4ff92SAndroid Build Coastguard Worker for (size_t untouchedIndex = 0; untouchedIndex < untouchedSubgraphs.size(); untouchedIndex++)
1456*89c4ff92SAndroid Build Coastguard Worker {
1457*89c4ff92SAndroid Build Coastguard Worker CheckUntouchedSubgraph(untouchedSubgraphs.at(untouchedIndex),
1458*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedSubgraphSizes.at(untouchedIndex),
1459*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedInputSlots.at(untouchedIndex),
1460*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedOutputSlots.at(untouchedIndex),
1461*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedLayers.at(untouchedIndex));
1462*89c4ff92SAndroid Build Coastguard Worker }
1463*89c4ff92SAndroid Build Coastguard Worker }
1464*89c4ff92SAndroid Build Coastguard Worker
1465*89c4ff92SAndroid Build Coastguard Worker // The input subgraph contains some unoptimizable layers ("unoptimizable" is added to the layer's name),
1466*89c4ff92SAndroid Build Coastguard Worker // this is meant to test input slots coming from different layers
PartiallyOptimizableSubgraphTestImpl2()1467*89c4ff92SAndroid Build Coastguard Worker void PartiallyOptimizableSubgraphTestImpl2()
1468*89c4ff92SAndroid Build Coastguard Worker {
1469*89c4ff92SAndroid Build Coastguard Worker Graph graph;
1470*89c4ff92SAndroid Build Coastguard Worker LayerNameToLayerMap layersInGraph;
1471*89c4ff92SAndroid Build Coastguard Worker
1472*89c4ff92SAndroid Build Coastguard Worker // Create a partially optimizable subgraph
1473*89c4ff92SAndroid Build Coastguard Worker SubgraphView::SubgraphViewPtr subgraphPtr = BuildPartiallyOptimizableSubgraph2(graph, layersInGraph);
1474*89c4ff92SAndroid Build Coastguard Worker CHECK((subgraphPtr != nullptr));
1475*89c4ff92SAndroid Build Coastguard Worker
1476*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IInputSlots& subgraphInputSlots = subgraphPtr->GetIInputSlots();
1477*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IOutputSlots& subgraphOutputSlots = subgraphPtr->GetIOutputSlots();
1478*89c4ff92SAndroid Build Coastguard Worker const SubgraphView::IConnectableLayers& subgraphLayers = subgraphPtr->GetIConnectableLayers();
1479*89c4ff92SAndroid Build Coastguard Worker
1480*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphInputSlots.size() == 2);
1481*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphOutputSlots.size() == 1);
1482*89c4ff92SAndroid Build Coastguard Worker CHECK(subgraphLayers.size() == 10);
1483*89c4ff92SAndroid Build Coastguard Worker
1484*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv1 layer"));
1485*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv2 layer unoptimizable"));
1486*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "conv3 layer"));
1487*89c4ff92SAndroid Build Coastguard Worker CHECK(Contains(layersInGraph, "add layer"));
1488*89c4ff92SAndroid Build Coastguard Worker
1489*89c4ff92SAndroid Build Coastguard Worker // Create a mock backend object
1490*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser initialiser; // Register the Mock Backend
1491*89c4ff92SAndroid Build Coastguard Worker auto backendObjPtr = CreateBackendObject(MockBackendId());
1492*89c4ff92SAndroid Build Coastguard Worker CHECK((backendObjPtr != nullptr));
1493*89c4ff92SAndroid Build Coastguard Worker
1494*89c4ff92SAndroid Build Coastguard Worker // Optimize the subgraph
1495*89c4ff92SAndroid Build Coastguard Worker OptimizationViews optimizationViews;
1496*89c4ff92SAndroid Build Coastguard Worker
1497*89c4ff92SAndroid Build Coastguard Worker // Check that the optimization is carried out correctly
1498*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(optimizationViews = backendObjPtr->OptimizeSubgraphView(*subgraphPtr));
1499*89c4ff92SAndroid Build Coastguard Worker
1500*89c4ff92SAndroid Build Coastguard Worker // ==============================================================================
1501*89c4ff92SAndroid Build Coastguard Worker // The expected results are:
1502*89c4ff92SAndroid Build Coastguard Worker // - Exactly one substitution, corresponding to the optimizable layers
1503*89c4ff92SAndroid Build Coastguard Worker // - No failed subgraphs
1504*89c4ff92SAndroid Build Coastguard Worker // - Exactly two untouched subgraphs, corresponding to the non-optimizable layer
1505*89c4ff92SAndroid Build Coastguard Worker // ==============================================================================
1506*89c4ff92SAndroid Build Coastguard Worker
1507*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1508*89c4ff92SAndroid Build Coastguard Worker // Check the substitutions
1509*89c4ff92SAndroid Build Coastguard Worker // -----------------------
1510*89c4ff92SAndroid Build Coastguard Worker
1511*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Substitutions& substitutions = optimizationViews.GetSubstitutions();
1512*89c4ff92SAndroid Build Coastguard Worker CHECK(substitutions.size() == 1);
1513*89c4ff92SAndroid Build Coastguard Worker
1514*89c4ff92SAndroid Build Coastguard Worker ExpectedSubgraphSize expectedSubstitutableSubgraphSizes{ 2, 1, 7 };
1515*89c4ff92SAndroid Build Coastguard Worker ExpectedSubgraphSize expectedReplacementSubgraphSizes{ 2, 1, 1 };
1516*89c4ff92SAndroid Build Coastguard Worker
1517*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IInputSlots expectedSubstitutableInputSlots
1518*89c4ff92SAndroid Build Coastguard Worker {
1519*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>({
1520*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv1 layer")->GetInputSlots()[0])})[0],
1521*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot, IInputSlot>({
1522*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv3 layer")->GetInputSlots()[0])})[0]
1523*89c4ff92SAndroid Build Coastguard Worker };
1524*89c4ff92SAndroid Build Coastguard Worker
1525*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IOutputSlots expectedSubstitutableOutputSlots
1526*89c4ff92SAndroid Build Coastguard Worker {
1527*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1528*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("add layer")->GetOutputSlots()))
1529*89c4ff92SAndroid Build Coastguard Worker };
1530*89c4ff92SAndroid Build Coastguard Worker
1531*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IConnectableLayers expectedSubstitutableLayers
1532*89c4ff92SAndroid Build Coastguard Worker {
1533*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 1"),
1534*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Weights Layer 3"),
1535*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 1"),
1536*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 3"),
1537*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv1 layer"),
1538*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv3 layer"),
1539*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("add layer")
1540*89c4ff92SAndroid Build Coastguard Worker };
1541*89c4ff92SAndroid Build Coastguard Worker
1542*89c4ff92SAndroid Build Coastguard Worker CheckSubstitution(substitutions[0],
1543*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableSubgraphSizes,
1544*89c4ff92SAndroid Build Coastguard Worker expectedReplacementSubgraphSizes,
1545*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableInputSlots,
1546*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableOutputSlots,
1547*89c4ff92SAndroid Build Coastguard Worker expectedSubstitutableLayers);
1548*89c4ff92SAndroid Build Coastguard Worker
1549*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1550*89c4ff92SAndroid Build Coastguard Worker // Check the failed subgraphs
1551*89c4ff92SAndroid Build Coastguard Worker // --------------------------
1552*89c4ff92SAndroid Build Coastguard Worker
1553*89c4ff92SAndroid Build Coastguard Worker CHECK(optimizationViews.GetFailedSubgraphs().empty());
1554*89c4ff92SAndroid Build Coastguard Worker
1555*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1556*89c4ff92SAndroid Build Coastguard Worker // Check the untouched subgraphs
1557*89c4ff92SAndroid Build Coastguard Worker // -----------------------------
1558*89c4ff92SAndroid Build Coastguard Worker
1559*89c4ff92SAndroid Build Coastguard Worker const OptimizationViews::Subgraphs& untouchedSubgraphs = optimizationViews.GetUntouchedSubgraphs();
1560*89c4ff92SAndroid Build Coastguard Worker CHECK(untouchedSubgraphs.size() == 1);
1561*89c4ff92SAndroid Build Coastguard Worker
1562*89c4ff92SAndroid Build Coastguard Worker std::vector<ExpectedSubgraphSize> expectedUntouchedSubgraphSizes{ { 1, 1, 3 } };
1563*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IInputSlots> expectedUntouchedInputSlots
1564*89c4ff92SAndroid Build Coastguard Worker {
1565*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<InputSlot,
1566*89c4ff92SAndroid Build Coastguard Worker IInputSlot>({ConvertReferenceTypeToPointerType(
1567*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("conv2 layer unoptimizable")->GetInputSlot(0))})};
1568*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IOutputSlots> expectedUntouchedOutputSlots
1569*89c4ff92SAndroid Build Coastguard Worker {
1570*89c4ff92SAndroid Build Coastguard Worker ConvertSlotsToISlots<OutputSlot, IOutputSlot>(
1571*89c4ff92SAndroid Build Coastguard Worker ConvertReferenceTypeToPointerType(layersInGraph.at("conv2 layer unoptimizable")->GetOutputSlots()))
1572*89c4ff92SAndroid Build Coastguard Worker };
1573*89c4ff92SAndroid Build Coastguard Worker std::vector<SubgraphView::IConnectableLayers> expectedUntouchedLayers
1574*89c4ff92SAndroid Build Coastguard Worker {
1575*89c4ff92SAndroid Build Coastguard Worker { layersInGraph.at("conv2 layer unoptimizable"), layersInGraph.at("Weights Layer 2 unoptimizable"),
1576*89c4ff92SAndroid Build Coastguard Worker layersInGraph.at("Bias Layer 2 unoptimizable") }
1577*89c4ff92SAndroid Build Coastguard Worker };
1578*89c4ff92SAndroid Build Coastguard Worker
1579*89c4ff92SAndroid Build Coastguard Worker for (size_t untouchedIndex = 0; untouchedIndex < untouchedSubgraphs.size(); untouchedIndex++)
1580*89c4ff92SAndroid Build Coastguard Worker {
1581*89c4ff92SAndroid Build Coastguard Worker CheckUntouchedSubgraph(untouchedSubgraphs.at(untouchedIndex),
1582*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedSubgraphSizes.at(untouchedIndex),
1583*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedInputSlots.at(untouchedIndex),
1584*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedOutputSlots.at(untouchedIndex),
1585*89c4ff92SAndroid Build Coastguard Worker expectedUntouchedLayers.at(untouchedIndex));
1586*89c4ff92SAndroid Build Coastguard Worker }
1587*89c4ff92SAndroid Build Coastguard Worker }
1588*89c4ff92SAndroid Build Coastguard Worker
1589*89c4ff92SAndroid Build Coastguard Worker } // Anonymous namespace
1590*89c4ff92SAndroid Build Coastguard Worker
1591*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OptimizeSubGraph")
1592*89c4ff92SAndroid Build Coastguard Worker {
1593*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnsupportedSubgraph1") { FullyUnsupporteSubgraphTestImpl1(); }
1594*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnsupportedSubgraph2") { FullyUnsupporteSubgraphTestImpl2(); }
1595*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyOptimizableSubgraph1") { FullyOptimizableSubgraphTestImpl1(); }
1596*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyOptimizableSubgraph2") { FullyOptimizableSubgraphTestImpl2(); }
1597*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallySupportedSubgraph") { PartiallySupportedSubgraphTestImpl(); }
1598*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("FullyUnoptimizableSubgraph") { FullyUnoptimizableSubgraphTestImpl1(); }
1599*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallyOptimizableSubgraph1") { PartiallyOptimizableSubgraphTestImpl1(); }
1600*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("PartiallyOptimizableSubgraph2") { PartiallyOptimizableSubgraphTestImpl2(); }
1601*89c4ff92SAndroid Build Coastguard Worker
1602*89c4ff92SAndroid Build Coastguard Worker }
1603