xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/GraphUtils.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "GraphUtils.hpp"

#include <armnn/utility/PolymorphicDowncast.hpp>

#include <algorithm>
9*89c4ff92SAndroid Build Coastguard Worker 
GraphHasNamedLayer(const armnn::Graph & graph,const std::string & name)10*89c4ff92SAndroid Build Coastguard Worker bool GraphHasNamedLayer(const armnn::Graph& graph, const std::string& name)
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : graph)
13*89c4ff92SAndroid Build Coastguard Worker     {
14*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetName() == name)
15*89c4ff92SAndroid Build Coastguard Worker         {
16*89c4ff92SAndroid Build Coastguard Worker             return true;
17*89c4ff92SAndroid Build Coastguard Worker         }
18*89c4ff92SAndroid Build Coastguard Worker     }
19*89c4ff92SAndroid Build Coastguard Worker     return false;
20*89c4ff92SAndroid Build Coastguard Worker }
21*89c4ff92SAndroid Build Coastguard Worker 
GetFirstLayerWithName(armnn::Graph & graph,const std::string & name)22*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* GetFirstLayerWithName(armnn::Graph& graph, const std::string& name)
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker     for (auto&& layer : graph)
25*89c4ff92SAndroid Build Coastguard Worker     {
26*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetNameStr() == name)
27*89c4ff92SAndroid Build Coastguard Worker         {
28*89c4ff92SAndroid Build Coastguard Worker             return layer;
29*89c4ff92SAndroid Build Coastguard Worker         }
30*89c4ff92SAndroid Build Coastguard Worker     }
31*89c4ff92SAndroid Build Coastguard Worker     return nullptr;
32*89c4ff92SAndroid Build Coastguard Worker }
33*89c4ff92SAndroid Build Coastguard Worker 
CheckNumberOfInputSlot(armnn::Layer * layer,unsigned int num)34*89c4ff92SAndroid Build Coastguard Worker bool CheckNumberOfInputSlot(armnn::Layer* layer, unsigned int num)
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker     return layer->GetNumInputSlots() == num;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker 
CheckNumberOfOutputSlot(armnn::Layer * layer,unsigned int num)39*89c4ff92SAndroid Build Coastguard Worker bool CheckNumberOfOutputSlot(armnn::Layer* layer, unsigned int num)
40*89c4ff92SAndroid Build Coastguard Worker {
41*89c4ff92SAndroid Build Coastguard Worker     return layer->GetNumOutputSlots() == num;
42*89c4ff92SAndroid Build Coastguard Worker }
43*89c4ff92SAndroid Build Coastguard Worker 
IsConnected(armnn::Layer * srcLayer,armnn::Layer * destLayer,unsigned int srcSlot,unsigned int destSlot,const armnn::TensorInfo & expectedTensorInfo)44*89c4ff92SAndroid Build Coastguard Worker bool IsConnected(armnn::Layer* srcLayer, armnn::Layer* destLayer,
45*89c4ff92SAndroid Build Coastguard Worker                  unsigned int srcSlot, unsigned int destSlot,
46*89c4ff92SAndroid Build Coastguard Worker                  const armnn::TensorInfo& expectedTensorInfo)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker     const armnn::IOutputSlot& outputSlot = srcLayer->GetOutputSlot(srcSlot);
49*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
50*89c4ff92SAndroid Build Coastguard Worker     if (expectedTensorInfo != tensorInfo)
51*89c4ff92SAndroid Build Coastguard Worker     {
52*89c4ff92SAndroid Build Coastguard Worker         return false;
53*89c4ff92SAndroid Build Coastguard Worker     }
54*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numConnections = outputSlot.GetNumConnections();
55*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int c = 0; c < numConnections; ++c)
56*89c4ff92SAndroid Build Coastguard Worker     {
57*89c4ff92SAndroid Build Coastguard Worker         auto inputSlot = armnn::PolymorphicDowncast<const armnn::InputSlot*>(outputSlot.GetConnection(c));
58*89c4ff92SAndroid Build Coastguard Worker         if (inputSlot->GetOwningLayer().GetNameStr() == destLayer->GetNameStr() &&
59*89c4ff92SAndroid Build Coastguard Worker             inputSlot->GetSlotIndex() == destSlot)
60*89c4ff92SAndroid Build Coastguard Worker         {
61*89c4ff92SAndroid Build Coastguard Worker             return true;
62*89c4ff92SAndroid Build Coastguard Worker         }
63*89c4ff92SAndroid Build Coastguard Worker     }
64*89c4ff92SAndroid Build Coastguard Worker     return false;
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker /// Checks that first comes before second in the order.
CheckOrder(const armnn::Graph & graph,const armnn::Layer * first,const armnn::Layer * second)68*89c4ff92SAndroid Build Coastguard Worker bool CheckOrder(const armnn::Graph& graph, const armnn::Layer* first, const armnn::Layer* second)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     graph.Print();
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     const auto& order = graph.TopologicalSort();
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     auto firstPos = std::find(order.begin(), order.end(), first);
75*89c4ff92SAndroid Build Coastguard Worker     auto secondPos = std::find(firstPos, order.end(), second);
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     return (secondPos != order.end());
78*89c4ff92SAndroid Build Coastguard Worker }
79