xref: /aosp_15_r20/external/armnn/src/armnn/layers/ElementwiseBaseLayer.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <Layer.hpp>
9 
10 namespace armnn
11 {
12 
13 /// Abstract base class that encapsulates the element wise operations. It deliberately does not implement:
14 /// std::unique_ptr<IWorkload> Layer::CreateWorkload(const IWorkloadFactory& factory) const = 0;
15 /// Layer* Clone(Graph& graph) const = 0;
16 class ElementwiseBaseLayer : public Layer
17 {
18 public:
19     /// Check if the input tensor shape(s)
20     /// will lead to a valid configuration of the element wise operation.
21     /// Takes no arguments and returns nothing; see the implementation for how a mismatch is reported.
22     void ValidateTensorShapesFromInputs() override;
23 
24     /// Infers the output shape(s) of the layer from the given input shapes.
25     /// NOTE(review): the exact inference rules (e.g. any broadcasting) live in the implementation; confirm there.
26     /// @param [in] inputShapes The shapes of the layer's inputs.
27     /// @return A vector of the inferred output shapes.
28     std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
29 
30     void ExecuteStrategy(IStrategy& strategy) const override; ///< Presumably applies @p strategy to this layer; confirm in .cpp.
31 
32 protected: // Construction and destruction are only accessible to derived classes.
33     /// @param numInputSlots The number of input slots for the layer.
34     /// @param numOutputSlots The number of output slots for the layer.
35     /// @param type The layer type.
36     /// @param name Optional name for the layer.
37     ElementwiseBaseLayer(unsigned int numInputSlots, unsigned int numOutputSlots, LayerType type, const char* name);
38 
39     /// Compiler-generated (defaulted) destructor, protected like the constructor.
40     ~ElementwiseBaseLayer() = default;
41 };
42 
43 } // namespace
44