1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/OptimizationViews.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker namespace armnn 9*89c4ff92SAndroid Build Coastguard Worker { 10*89c4ff92SAndroid Build Coastguard Worker Validate(const armnn::SubgraphView & originalSubgraph) const11*89c4ff92SAndroid Build Coastguard Workerbool OptimizationViews::Validate(const armnn::SubgraphView& originalSubgraph) const 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker //This needs to verify that: 14*89c4ff92SAndroid Build Coastguard Worker // 1) the sum of m_SuccesfulOptimizations & m_FailedOptimizations & m_UntouchedSubgraphs contains subgraphviews 15*89c4ff92SAndroid Build Coastguard Worker // which cover the entire space of the originalSubgraph. 16*89c4ff92SAndroid Build Coastguard Worker // 2) Each SubstitutionPair contains matching inputs and outputs 17*89c4ff92SAndroid Build Coastguard Worker bool valid = true; 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker // Create a copy of the layer list from the original subgraph and sort it 20*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IConnectableLayers originalLayers = originalSubgraph.GetIConnectableLayers(); 21*89c4ff92SAndroid Build Coastguard Worker originalLayers.sort(); 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker // Create a new list based on the sum of all the subgraphs and sort it 24*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IConnectableLayers countedLayers; 25*89c4ff92SAndroid Build Coastguard Worker for (auto& failed : m_FailedOptimizations) 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker countedLayers.insert(countedLayers.end(), 28*89c4ff92SAndroid Build Coastguard Worker failed.GetIConnectableLayers().begin(), 29*89c4ff92SAndroid Build Coastguard Worker failed.GetIConnectableLayers().end()); 30*89c4ff92SAndroid Build Coastguard Worker } 31*89c4ff92SAndroid Build Coastguard Worker for (auto& untouched : m_UntouchedSubgraphs) 32*89c4ff92SAndroid Build Coastguard Worker { 33*89c4ff92SAndroid Build Coastguard Worker countedLayers.insert(countedLayers.end(), 34*89c4ff92SAndroid Build Coastguard Worker untouched.GetIConnectableLayers().begin(), 35*89c4ff92SAndroid Build Coastguard Worker untouched.GetIConnectableLayers().end()); 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker for (auto& successful : m_SuccesfulOptimizations) 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker countedLayers.insert(countedLayers.end(), 40*89c4ff92SAndroid Build Coastguard Worker successful.m_SubstitutableSubgraph.GetIConnectableLayers().begin(), 41*89c4ff92SAndroid Build Coastguard Worker successful.m_SubstitutableSubgraph.GetIConnectableLayers().end()); 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker countedLayers.sort(); 44*89c4ff92SAndroid Build Coastguard Worker 45*89c4ff92SAndroid Build Coastguard Worker // Compare the two lists to make sure they match 46*89c4ff92SAndroid Build Coastguard Worker valid &= originalLayers.size() == countedLayers.size(); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker auto oIt = originalLayers.begin(); 49*89c4ff92SAndroid Build Coastguard Worker auto cIt = countedLayers.begin(); 50*89c4ff92SAndroid Build Coastguard Worker for (size_t i=0; i < originalLayers.size() && valid; ++i, ++oIt, ++cIt) 51*89c4ff92SAndroid Build Coastguard Worker { 52*89c4ff92SAndroid Build Coastguard Worker valid &= (*oIt == *cIt); 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker // Compare the substitution subgraphs to ensure they are compatible 56*89c4ff92SAndroid Build Coastguard Worker if (valid) 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker for (auto& substitution : m_SuccesfulOptimizations) 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker bool validSubstitution = true; 61*89c4ff92SAndroid Build Coastguard Worker const SubgraphView& replacement = substitution.m_ReplacementSubgraph; 62*89c4ff92SAndroid Build Coastguard Worker const SubgraphView& old = substitution.m_SubstitutableSubgraph; 63*89c4ff92SAndroid Build Coastguard Worker validSubstitution &= replacement.GetIInputSlots().size() == old.GetIInputSlots().size(); 64*89c4ff92SAndroid Build Coastguard Worker validSubstitution &= replacement.GetIOutputSlots().size() == old.GetIOutputSlots().size(); 65*89c4ff92SAndroid Build Coastguard Worker valid &= validSubstitution; 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker return valid; 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn 71