//
// Copyright © 2021,2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

// NOTE: the MultiLayerFacade class is a utility class which makes a chain
// of operators look like a single IConnectableLayer, with the first
// layer in the chain supplying the input slots and the last supplying
// the output slots. It enables us, for example, to simulate a
// TensorFlow Lite FloorDiv operator by chaining a Div layer followed
// by a Floor layer and passing them as a single unit to the code that
// connects up the graph as the delegate proceeds to build up the
// Arm NN subgraphs.
16*89c4ff92SAndroid Build Coastguard Worker // 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker #include <common/include/ProfilingGuid.hpp> 19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp> 20*89c4ff92SAndroid Build Coastguard Worker 21*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker class MultiLayerFacade : public armnn::IConnectableLayer 25*89c4ff92SAndroid Build Coastguard Worker { 26*89c4ff92SAndroid Build Coastguard Worker public: MultiLayerFacade()27*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade() : 28*89c4ff92SAndroid Build Coastguard Worker m_FirstLayer(nullptr), m_LastLayer(nullptr) {} 29*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade(armnn::IConnectableLayer * firstLayer,armnn::IConnectableLayer * lastLayer)30*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade(armnn::IConnectableLayer* firstLayer, armnn::IConnectableLayer* lastLayer) : 31*89c4ff92SAndroid Build Coastguard Worker m_FirstLayer(firstLayer), m_LastLayer(lastLayer) {} 32*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade(const MultiLayerFacade & obj)33*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade(const MultiLayerFacade& obj) : 34*89c4ff92SAndroid Build Coastguard Worker m_FirstLayer(obj.m_FirstLayer), m_LastLayer(obj.m_LastLayer) {} 35*89c4ff92SAndroid Build Coastguard Worker ~MultiLayerFacade()36*89c4ff92SAndroid Build Coastguard Worker ~MultiLayerFacade() {} // we don't own the pointers 37*89c4ff92SAndroid Build Coastguard Worker operator =(const MultiLayerFacade & obj)38*89c4ff92SAndroid Build Coastguard Worker MultiLayerFacade& operator=(const MultiLayerFacade& obj) 39*89c4ff92SAndroid Build Coastguard Worker { 40*89c4ff92SAndroid Build Coastguard Worker m_FirstLayer = obj.m_FirstLayer; 41*89c4ff92SAndroid Build Coastguard Worker 
m_LastLayer = obj.m_LastLayer; 42*89c4ff92SAndroid Build Coastguard Worker return *this; 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker AssignValues(armnn::IConnectableLayer * firstLayer,armnn::IConnectableLayer * lastLayer)45*89c4ff92SAndroid Build Coastguard Worker void AssignValues(armnn::IConnectableLayer* firstLayer, armnn::IConnectableLayer* lastLayer) 46*89c4ff92SAndroid Build Coastguard Worker { 47*89c4ff92SAndroid Build Coastguard Worker m_FirstLayer = firstLayer; 48*89c4ff92SAndroid Build Coastguard Worker m_LastLayer = lastLayer; 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker GetName() const51*89c4ff92SAndroid Build Coastguard Worker virtual const char* GetName() const override 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetName(); 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker GetNumInputSlots() const56*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetNumInputSlots() const override 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetNumInputSlots(); 59*89c4ff92SAndroid Build Coastguard Worker } 60*89c4ff92SAndroid Build Coastguard Worker GetNumOutputSlots() const61*89c4ff92SAndroid Build Coastguard Worker virtual unsigned int GetNumOutputSlots() const override 62*89c4ff92SAndroid Build Coastguard Worker { 63*89c4ff92SAndroid Build Coastguard Worker return m_LastLayer->GetNumOutputSlots(); 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker GetInputSlot(unsigned int index) const66*89c4ff92SAndroid Build Coastguard Worker virtual const armnn::IInputSlot& GetInputSlot(unsigned int index) const override 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetInputSlot(index); 
69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker GetInputSlot(unsigned int index)71*89c4ff92SAndroid Build Coastguard Worker virtual armnn::IInputSlot& GetInputSlot(unsigned int index) override 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetInputSlot(index); 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker GetOutputSlot(unsigned int index) const76*89c4ff92SAndroid Build Coastguard Worker virtual const armnn::IOutputSlot& GetOutputSlot(unsigned int index) const override 77*89c4ff92SAndroid Build Coastguard Worker { 78*89c4ff92SAndroid Build Coastguard Worker return m_LastLayer->GetOutputSlot(index); 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker GetOutputSlot(unsigned int index)81*89c4ff92SAndroid Build Coastguard Worker virtual armnn::IOutputSlot& GetOutputSlot(unsigned int index) override 82*89c4ff92SAndroid Build Coastguard Worker { 83*89c4ff92SAndroid Build Coastguard Worker return m_LastLayer->GetOutputSlot(index); 84*89c4ff92SAndroid Build Coastguard Worker } 85*89c4ff92SAndroid Build Coastguard Worker InferOutputShapes(const std::vector<armnn::TensorShape> & inputShapes) const86*89c4ff92SAndroid Build Coastguard Worker virtual std::vector<armnn::TensorShape> InferOutputShapes( 87*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::TensorShape>& inputShapes) const override 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker // NOTE: do not expect this function to be used. Likely that if it is it might need to be overridden 90*89c4ff92SAndroid Build Coastguard Worker // for particular sequences of operators. 
91*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->InferOutputShapes(inputShapes); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker GetGuid() const94*89c4ff92SAndroid Build Coastguard Worker virtual LayerGuid GetGuid() const override 95*89c4ff92SAndroid Build Coastguard Worker { 96*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetGuid(); 97*89c4ff92SAndroid Build Coastguard Worker } 98*89c4ff92SAndroid Build Coastguard Worker ExecuteStrategy(armnn::IStrategy & strategy) const99*89c4ff92SAndroid Build Coastguard Worker virtual void ExecuteStrategy(armnn::IStrategy& strategy) const override 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker // Do not expect this function to be used so not providing an implementation 102*89c4ff92SAndroid Build Coastguard Worker // if an implementation is required and the chain contains more than two operators 103*89c4ff92SAndroid Build Coastguard Worker // would have to provide a way to record the intermediate layers so they could be 104*89c4ff92SAndroid Build Coastguard Worker // visited... the same applies to the BackendSelectionHint 105*89c4ff92SAndroid Build Coastguard Worker // below. 
106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker BackendSelectionHint(armnn::Optional<armnn::BackendId> backend)108*89c4ff92SAndroid Build Coastguard Worker virtual void BackendSelectionHint(armnn::Optional<armnn::BackendId> backend) override 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker // Do not expect this function to be used so not providing an implementation 111*89c4ff92SAndroid Build Coastguard Worker } 112*89c4ff92SAndroid Build Coastguard Worker GetType() const113*89c4ff92SAndroid Build Coastguard Worker virtual armnn::LayerType GetType() const override 114*89c4ff92SAndroid Build Coastguard Worker { 115*89c4ff92SAndroid Build Coastguard Worker return m_FirstLayer->GetType(); 116*89c4ff92SAndroid Build Coastguard Worker } 117*89c4ff92SAndroid Build Coastguard Worker GetParameters() const118*89c4ff92SAndroid Build Coastguard Worker virtual const armnn::BaseDescriptor& GetParameters() const override { return m_NullDescriptor; } 119*89c4ff92SAndroid Build Coastguard Worker SetBackendId(const armnn::BackendId & id)120*89c4ff92SAndroid Build Coastguard Worker void SetBackendId(const armnn::BackendId& id) override {} 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker protected: 123*89c4ff92SAndroid Build Coastguard Worker /// Retrieve the handles to the constant values stored by the layer. 124*89c4ff92SAndroid Build Coastguard Worker /// @return A vector of the constant tensors stored by this layer. 
GetConstantTensorsByRef()125*89c4ff92SAndroid Build Coastguard Worker ConstantTensors GetConstantTensorsByRef() override { return {}; } GetConstantTensorsByRef() const126*89c4ff92SAndroid Build Coastguard Worker ImmutableConstantTensors GetConstantTensorsByRef() const override { return {}; } 127*89c4ff92SAndroid Build Coastguard Worker 128*89c4ff92SAndroid Build Coastguard Worker private: 129*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* m_FirstLayer; 130*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* m_LastLayer; 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker // to satisfy the GetParameters method need to hand back a NullDescriptor 133*89c4ff92SAndroid Build Coastguard Worker armnn::NullDescriptor m_NullDescriptor; 134*89c4ff92SAndroid Build Coastguard Worker }; 135*89c4ff92SAndroid Build Coastguard Worker 136*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate 137