xref: /aosp_15_r20/external/armnn/src/armnn/layers/ElementwiseBaseLayer.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <Layer.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker namespace armnn
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker /// NOTE: this is an abstract class to encapsulate the element wise operations, it does not implement:
14*89c4ff92SAndroid Build Coastguard Worker /// std::unique_ptr<IWorkload> Layer::CreateWorkload(const IWorkloadFactory& factory) const = 0;
15*89c4ff92SAndroid Build Coastguard Worker /// Layer* Clone(Graph& graph) const = 0;
16*89c4ff92SAndroid Build Coastguard Worker class ElementwiseBaseLayer : public Layer
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker public:
19*89c4ff92SAndroid Build Coastguard Worker     /// Check if the input tensor shape(s)
20*89c4ff92SAndroid Build Coastguard Worker     /// will lead to a valid configuration of the element wise operation.
21*89c4ff92SAndroid Build Coastguard Worker     /// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validated.
22*89c4ff92SAndroid Build Coastguard Worker     void ValidateTensorShapesFromInputs() override;
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker     /// By default returns inputShapes if the number of inputs are equal to number of outputs,
25*89c4ff92SAndroid Build Coastguard Worker     /// otherwise infers the output shapes from given input shapes and layer properties.
26*89c4ff92SAndroid Build Coastguard Worker     /// @param [in] inputShapes The input shapes layer has.
27*89c4ff92SAndroid Build Coastguard Worker     /// @return A vector to the inferred output shape.
28*89c4ff92SAndroid Build Coastguard Worker     std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker     void ExecuteStrategy(IStrategy& strategy) const override;
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker protected:
33*89c4ff92SAndroid Build Coastguard Worker     /// @param numInputSlots The number of input slots for the layer.
34*89c4ff92SAndroid Build Coastguard Worker     /// @param numOutputSlots The number of output slots for the layer.
35*89c4ff92SAndroid Build Coastguard Worker     /// @param type The layer type.
36*89c4ff92SAndroid Build Coastguard Worker     /// @param name Optional name for the layer.
37*89c4ff92SAndroid Build Coastguard Worker     ElementwiseBaseLayer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name);
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker     /// Default destructor
40*89c4ff92SAndroid Build Coastguard Worker     ~ElementwiseBaseLayer() = default;
41*89c4ff92SAndroid Build Coastguard Worker };
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker } // namespace
44