1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
#include "GraphUtils.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <algorithm>
9*89c4ff92SAndroid Build Coastguard Worker
GraphHasNamedLayer(const armnn::Graph & graph,const std::string & name)10*89c4ff92SAndroid Build Coastguard Worker bool GraphHasNamedLayer(const armnn::Graph& graph, const std::string& name)
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : graph)
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker if (layer->GetName() == name)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker return true;
17*89c4ff92SAndroid Build Coastguard Worker }
18*89c4ff92SAndroid Build Coastguard Worker }
19*89c4ff92SAndroid Build Coastguard Worker return false;
20*89c4ff92SAndroid Build Coastguard Worker }
21*89c4ff92SAndroid Build Coastguard Worker
GetFirstLayerWithName(armnn::Graph & graph,const std::string & name)22*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* GetFirstLayerWithName(armnn::Graph& graph, const std::string& name)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : graph)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker if (layer->GetNameStr() == name)
27*89c4ff92SAndroid Build Coastguard Worker {
28*89c4ff92SAndroid Build Coastguard Worker return layer;
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker return nullptr;
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker
CheckNumberOfInputSlot(armnn::Layer * layer,unsigned int num)34*89c4ff92SAndroid Build Coastguard Worker bool CheckNumberOfInputSlot(armnn::Layer* layer, unsigned int num)
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker return layer->GetNumInputSlots() == num;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker
CheckNumberOfOutputSlot(armnn::Layer * layer,unsigned int num)39*89c4ff92SAndroid Build Coastguard Worker bool CheckNumberOfOutputSlot(armnn::Layer* layer, unsigned int num)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker return layer->GetNumOutputSlots() == num;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker
IsConnected(armnn::Layer * srcLayer,armnn::Layer * destLayer,unsigned int srcSlot,unsigned int destSlot,const armnn::TensorInfo & expectedTensorInfo)44*89c4ff92SAndroid Build Coastguard Worker bool IsConnected(armnn::Layer* srcLayer, armnn::Layer* destLayer,
45*89c4ff92SAndroid Build Coastguard Worker unsigned int srcSlot, unsigned int destSlot,
46*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& expectedTensorInfo)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker const armnn::IOutputSlot& outputSlot = srcLayer->GetOutputSlot(srcSlot);
49*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
50*89c4ff92SAndroid Build Coastguard Worker if (expectedTensorInfo != tensorInfo)
51*89c4ff92SAndroid Build Coastguard Worker {
52*89c4ff92SAndroid Build Coastguard Worker return false;
53*89c4ff92SAndroid Build Coastguard Worker }
54*89c4ff92SAndroid Build Coastguard Worker const unsigned int numConnections = outputSlot.GetNumConnections();
55*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < numConnections; ++c)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker auto inputSlot = armnn::PolymorphicDowncast<const armnn::InputSlot*>(outputSlot.GetConnection(c));
58*89c4ff92SAndroid Build Coastguard Worker if (inputSlot->GetOwningLayer().GetNameStr() == destLayer->GetNameStr() &&
59*89c4ff92SAndroid Build Coastguard Worker inputSlot->GetSlotIndex() == destSlot)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker return true;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker return false;
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker /// Checks that first comes before second in the order.
CheckOrder(const armnn::Graph & graph,const armnn::Layer * first,const armnn::Layer * second)68*89c4ff92SAndroid Build Coastguard Worker bool CheckOrder(const armnn::Graph& graph, const armnn::Layer* first, const armnn::Layer* second)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker graph.Print();
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker const auto& order = graph.TopologicalSort();
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker auto firstPos = std::find(order.begin(), order.end(), first);
75*89c4ff92SAndroid Build Coastguard Worker auto secondPos = std::find(firstPos, order.end(), second);
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker return (secondPos != order.end());
78*89c4ff92SAndroid Build Coastguard Worker }
79